Skip to content

Commit

Permalink
Add opinfo test
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Apr 3, 2024
1 parent 895b0c2 commit 0612f48
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 32 deletions.
18 changes: 0 additions & 18 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,6 @@
from torch.utils import _pytree as pytree


def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
if isinstance(output1, torch.Tensor):
testcase.assertIsInstance(output2, torch.Tensor)
output2_cpu = output2.detach().cpu()
if output2_cpu.dtype != output1.dtype:
output2_cpu = output2_cpu.to(output1.dtype)
testcase.assertTrue(
torch.allclose(
output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
elif isinstance(output1, (tuple, list)):
testcase.assertIsInstance(output2, (tuple, list))
testcase.assertEqual(len(output1), len(output2))
for o1, o2 in zip(output1, output2):
diff_output(testcase, o1, o2, rtol, atol)
else:
testcase.assertEqual(output1, output2)


def run_export_and_compare(testcase,
func,
args,
Expand Down
19 changes: 5 additions & 14 deletions experimental/torch_xla2/torch_xla2/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def _aten_add(x, y):
assert x.dtype == y.dtype, (x.dtype, y.dtype)
"""
try:
return x + y
except Exception as e:
import pdb

pdb.set_trace()
return x + y


@op(torch.ops.aten.add_, is_jax_func=False)
Expand Down Expand Up @@ -239,6 +234,7 @@ def _aten_div(x, y, rounding_mode=""):
res = jnp.trunc(res)
return res


@op(torch.ops.aten.div_, is_jax_func=False)
def _aten_div_(x, y, rounding_mode=""):
x._elem = _aten_div(x._elem, y._elem, rounding_mode)
Expand Down Expand Up @@ -376,14 +372,9 @@ def _aten_ne(x, y):

@op(torch.ops.aten.cumsum)
def _aten_cumsum(x, y, dtype=None):
try:
dtype = tensor.t2j_dtype(dtype)
res = jnp.cumsum(x, y, dtype)
return res
except Exception as e:
import pdb

pdb.set_trace()
dtype = tensor.t2j_dtype(dtype)
res = jnp.cumsum(x, y, dtype)
return res


@op(torch.ops.aten.native_layer_norm)
Expand Down

0 comments on commit 0612f48

Please sign in to comment.