diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 2a156cbce6c..5f6fdbbeab2 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -7,6 +7,8 @@
     instantiate_device_type_tests, ops)
 from torch.utils import _pytree as pytree
 from torch_xla2 import tensor
+import torch_xla2
+
 
 skiplist = {
     "__getitem__",
@@ -15,18 +17,6 @@
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addmm",
-    "addmv",
-    "addr",
-    "all",
-    "allclose",
-    "amax",
-    "amin",
-    "aminmax",
-    "angle",
-    "any",
-    "argmax",
-    "argmin",
     "argsort",
     "as_strided",
     "as_strided_scatter",
@@ -639,7 +629,8 @@ def run_export_and_compare(testcase,
   input2, args2, kwargs2 = pytree.tree_map_only(
       torch.Tensor, tensor.move_to_device,
       (sample_input.input, sample_input.args, sample_input.kwargs))
-  res2 = func(input2, *args2, **kwargs2)
+  with torch_xla2.mode():
+    res2 = func(input2, *args2, **kwargs2)
   res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
   with testcase.subTest("torch_xla2_diff:" + str(atol)):
     if ignore_indices and isinstance(res, tuple) and len(res) == 2:
diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py
index 0eacf2d47a3..e3650234372 100644
--- a/experimental/torch_xla2/torch_xla2/_ops.py
+++ b/experimental/torch_xla2/torch_xla2/_ops.py
@@ -410,7 +410,10 @@ def _aten_native_layer_norm(input,
 
 # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 @op(torch.ops.aten.addmm)
+@op(torch.ops.aten.addmv)
 def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
+  alpha = jnp.array(alpha).astype(mat1.dtype)
+  beta = jnp.array(beta).astype(mat1.dtype)
   self *= beta
   self += alpha * jnp.matmul(mat1, mat2)
   return self
@@ -641,13 +644,14 @@ def _aten_min(x, axis=None):
 
 
 @op(torch.ops.aten.amin)
-def _aten_amin(x, axis=None):
-  return jnp.min(x, axis=axis)
+def _aten_amin(x, dim=None, keepdim=False):
+  return _with_reduction_scalar(jnp.amin, x, dim, keepdim)
 
 
 @op(torch.ops.aten.argmin)
-def _aten_amin(x, axis=None):
-  return jnp.argmin(x, axis=axis)
+def _aten_argmin(self, dim=None, keepdim=False):
+  return _with_reduction_scalar(
+      jnp.argmin, self, dim, keepdim)
 
 
 @op(torch.ops.aten.sin)
@@ -1211,13 +1215,27 @@ def _aten_abs(self):
 # generate aten.amax only
 @op(torch.ops.aten.amax)
 def _aten_amax(self, dim=None, keepdim=False):
-  return jnp.amax(self, axis=dim, keepdims=keepdim)
-
+  return _with_reduction_scalar(jnp.amax, self, dim, keepdim)
+
+
+def _with_reduction_scalar(jax_func, self, dim, keepdim):
+  expanded = False
+  if self.ndim == 0:
+    # for self of rank 0:
+    # torch.any(x, 0) and torch.any(x, -1) work;
+    # torch.any(x, 1) throws out of bounds, so its
+    # behavior is the same as a jnp array of rank 1
+    expanded = True
+    self = jnp.expand_dims(self, 0)
+  res = jax_func(self, axis=dim, keepdims=keepdim)
+  if expanded:
+    res = res.squeeze()
+  return res
 
 # aten.any
 @op(torch.ops.aten.any)
 def _aten_any(self, dim=None, keepdim=False):
-  return jnp.any(self, axis=dim, keepdims=keepdim)
+  return _with_reduction_scalar(jnp.any, self, dim, keepdim)
 
 
 # aten.arange
@@ -1246,7 +1264,8 @@ def _aten_arange(start,
 # aten.argmax
 @op(torch.ops.aten.argmax)
 def _aten_argmax(self, dim=None, keepdim=False):
-  return jnp.argmax(self, axis=dim, keepdims=keepdim)
+  return _with_reduction_scalar(
+      jnp.argmax, self, dim, keepdim)
 
 
 # aten.as_strided
@@ -1751,4 +1770,12 @@ def _aten_local_scalar_dense(x):
 
 @op(torch.ops.aten.tensor_split.sections)
 def _aten_tensor_split(ary, indices_or_sections, axis=0):
-  return jnp.array_split(ary, indices_or_sections, axis)
\ No newline at end of file
+  return jnp.array_split(ary, indices_or_sections, axis)
+
+@op(torch.ops.aten.outer)
+def _aten_outer(a, b):
+  return jnp.outer(a, b)
+
+@op(torch.ops.aten.allclose)
+def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+  return jnp.allclose(input, other, rtol, atol, equal_nan)
\ No newline at end of file
diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py
index 9fcd5653a86..94320fd7cb2 100644
--- a/experimental/torch_xla2/torch_xla2/functions.py
+++ b/experimental/torch_xla2/torch_xla2/functions.py
@@ -92,6 +92,32 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
   # TODO: handle torch.Size
   return jnp.full(size, fill_value, dtype=dtype)
 
+@register_function(torch.allclose)
+def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+  return jnp.allclose(input, other, rtol, atol, equal_nan)
+
+@register_function(torch.angle)
+def _torch_angle(input):
+  return jnp.angle(input)
+
+
+@register_function(torch.argsort)
+def _torch_argsort(input, dim=-1, descending=False, stable=False):
+  expanded = False
+  if input.ndim == 0:
+    # for input of rank 0:
+    # torch.argsort(x, 0) and torch.argsort(x, -1) work;
+    # torch.argsort(x, 1) throws out of bounds, so its
+    # behavior is the same as a jnp array of rank 1
+    expanded = True
+    input = jnp.expand_dims(input, 0)
+  res = jnp.argsort(input, axis=dim, descending=descending,
+                    stable=stable)
+  if expanded:
+    res = res.squeeze()
+  return res
+
+
 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
   """Context manager that dispatches torch function calls to JAX."""
diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
index e69de29bb2d..6628b7e9510 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -0,0 +1,7 @@
+import torch
+
+
+
+torch_ops_override = {
+    torch.allclose: torch.ops.aten.allclose
+}
\ No newline at end of file
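
Note (not part of the patch): a minimal sketch of the rank-0 corner case that _with_reduction_scalar and _torch_argsort work around. PyTorch accepts dim=0 and dim=-1 on a scalar tensor, while jnp reductions reject any axis on a rank-0 array, so the helpers expand to rank 1, reduce, and squeeze the result back. Only standard torch and jax.numpy calls are used below.

import jax.numpy as jnp
import torch

x = torch.tensor(3.0)         # rank-0 torch tensor
print(torch.amax(x, 0))       # tensor(3.); dim=0 and dim=-1 are accepted
# torch.amax(x, 1)            # raises IndexError: dimension out of range

j = jnp.asarray(3.0)          # rank-0 jax array
# jnp.amax(j, axis=0)         # raises: axis 0 is out of bounds for rank 0
j1 = jnp.expand_dims(j, 0)    # what the helper does before reducing
print(jnp.amax(j1, axis=0))   # reduces like the torch call above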
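
Usage sketch (an assumption about intended use, based only on names that appear in this diff: tensor.move_to_device, torch_xla2.mode(), and the torch.allclose registration in functions.py); it is not taken from the patch itself.

import torch
import torch_xla2
from torch_xla2 import tensor

a = tensor.move_to_device(torch.randn(4))
b = tensor.move_to_device(torch.randn(4))

# Inside torch_xla2.mode(), torch.* calls dispatch through the registered
# XLAFunctionMode, so torch.allclose reaches the jnp.allclose lowering
# registered above.
with torch_xla2.mode():
  same = torch.allclose(a, a)
  different = torch.allclose(a, b)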