diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 357e41c91018..591355088984 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -8,24 +8,6 @@ from torch.utils import _pytree as pytree -def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): - if isinstance(output1, torch.Tensor): - testcase.assertIsInstance(output2, torch.Tensor) - output2_cpu = output2.detach().cpu() - if output2_cpu.dtype != output1.dtype: - output2_cpu = output2_cpu.to(output1.dtype) - testcase.assertTrue( - torch.allclose( - output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan)) - elif isinstance(output1, (tuple, list)): - testcase.assertIsInstance(output2, (tuple, list)) - testcase.assertEqual(len(output1), len(output2)) - for o1, o2 in zip(output1, output2): - diff_output(testcase, o1, o2, rtol, atol) - else: - testcase.assertEqual(output1, output2) - - def run_export_and_compare(testcase, func, args, diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py index 334c63bb4239..0bb5656de5a1 100644 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ b/experimental/torch_xla2/torch_xla2/_ops.py @@ -57,12 +57,7 @@ def _aten_add(x, y): assert x.dtype == y.dtype, (x.dtype, y.dtype) """ - try: - return x + y - except Exception as e: - import pdb - - pdb.set_trace() + return x + y @op(torch.ops.aten.add_, is_jax_func=False) @@ -239,6 +234,7 @@ def _aten_div(x, y, rounding_mode=""): res = jnp.trunc(res) return res + @op(torch.ops.aten.div_, is_jax_func=False) def _aten_div_(x, y, rounding_mode=""): x._elem = _aten_div(x._elem, y._elem, rounding_mode) @@ -376,14 +372,9 @@ def _aten_ne(x, y): @op(torch.ops.aten.cumsum) def _aten_cumsum(x, y, dtype=None): - try: - dtype = tensor.t2j_dtype(dtype) - res = jnp.cumsum(x, y, dtype) - return res - except Exception as e: - import pdb - - pdb.set_trace() + dtype = tensor.t2j_dtype(dtype) + res = jnp.cumsum(x, y, dtype) + return res @op(torch.ops.aten.native_layer_norm)