Fix more opinfo tests (#7008)
qihqi authored May 1, 2024
1 parent 2399e10 commit 2907ab3
Showing 4 changed files with 73 additions and 22 deletions.
17 changes: 4 additions & 13 deletions experimental/torch_xla2/test/test_ops.py
@@ -7,6 +7,8 @@
    instantiate_device_type_tests, ops)
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
import torch_xla2


skiplist = {
    "__getitem__",
@@ -15,18 +17,6 @@
    "_native_batch_norm_legit",
    "_segment_reduce",
    "_upsample_bilinear2d_aa",
-    "addmm",
-    "addmv",
-    "addr",
-    "all",
-    "allclose",
-    "amax",
-    "amin",
-    "aminmax",
-    "angle",
-    "any",
-    "argmax",
-    "argmin",
    "argsort",
    "as_strided",
    "as_strided_scatter",
@@ -639,7 +629,8 @@ def run_export_and_compare(testcase,
  input2, args2, kwargs2 = pytree.tree_map_only(
      torch.Tensor, tensor.move_to_device,
      (sample_input.input, sample_input.args, sample_input.kwargs))
-  res2 = func(input2, *args2, **kwargs2)
+  with torch_xla2.mode():
+    res2 = func(input2, *args2, **kwargs2)
  res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
  with testcase.subTest("torch_xla2_diff:" + str(atol)):
    if ignore_indices and isinstance(res, tuple) and len(res) == 2:
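For readers skimming the test change above: the op invocation is now wrapped in `torch_xla2.mode()`, so torch-level calls an op makes internally (e.g. `torch.allclose`, `torch.angle`) are dispatched to the JAX-backed implementations instead of eager PyTorch. A minimal sketch of the same pattern outside the OpInfo harness, assuming `torch_xla2` is installed and that `torch_xla2.mode()`, `tensor.move_to_device`, and `XLATensor2.torch()` behave as used in the diff:

import torch
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
import torch_xla2

def compare_with_eager(func, *args, **kwargs):
  # Reference result on eager PyTorch.
  expected = func(*args, **kwargs)
  # Move tensor inputs onto the JAX backend.
  args2, kwargs2 = pytree.tree_map_only(
      torch.Tensor, tensor.move_to_device, (args, kwargs))
  # Inside mode(), torch.* calls dispatch to the JAX lowerings.
  with torch_xla2.mode():
    actual = func(*args2, **kwargs2)
  # Convert XLATensor2 results back to torch.Tensor for comparison.
  actual = pytree.tree_map_only(
      tensor.XLATensor2, lambda t: t.torch(), actual)
  torch.testing.assert_close(actual, expected, check_device=False)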
45 changes: 36 additions & 9 deletions experimental/torch_xla2/torch_xla2/_ops.py
@@ -410,7 +410,10 @@ def _aten_native_layer_norm(input,

# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
@op(torch.ops.aten.addmm)
@op(torch.ops.aten.addmv)
def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
  alpha = jnp.array(alpha).astype(mat1.dtype)
  beta = jnp.array(beta).astype(mat1.dtype)
  self *= beta
  self += alpha * jnp.matmul(mat1, mat2)
  return self
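A side note on the `alpha`/`beta` casts introduced above: the OpInfo suite also feeds integer tensors to `addmm`/`addmv`, and multiplying an integer JAX array by a bare Python float such as `1.0` promotes the result to floating point, whereas PyTorch keeps the input dtype. Casting the scalars to `mat1.dtype` first avoids that mismatch. A small standalone jax.numpy illustration (the integer-dtype failure mode is an assumption; the commit does not say which samples failed):

import jax.numpy as jnp

mat1 = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
mat2 = jnp.ones((3, 2), dtype=jnp.int32)
bias = jnp.zeros((2, 2), dtype=jnp.int32)

# A bare Python float promotes the integer matmul to floating point ...
promoted = bias + 1.0 * jnp.matmul(mat1, mat2)
print(promoted.dtype)  # float32

# ... while casting the scalar first keeps the integer dtype, as torch.addmm does.
alpha = jnp.array(1.0).astype(mat1.dtype)
kept = bias + alpha * jnp.matmul(mat1, mat2)
print(kept.dtype)  # int32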
@@ -641,13 +644,14 @@ def _aten_min(x, axis=None):


@op(torch.ops.aten.amin)
-def _aten_amin(x, axis=None):
-  return jnp.min(x, axis=axis)
+def _aten_amin(x, dim=None, keepdim=False):
+  return _with_reduction_scalar(jnp.amin, x, dim, keepdim)


@op(torch.ops.aten.argmin)
-def _aten_amin(x, axis=None):
-  return jnp.argmin(x, axis=axis)
+def _aten_argmin(self, dim=None, keepdim=False):
+  return _with_reduction_scalar(
+      jnp.argmin, self, dim, keepdim)


@op(torch.ops.aten.sin)
@@ -1211,13 +1215,27 @@ def _aten_abs(self):
# generate aten.amax only
@op(torch.ops.aten.amax)
def _aten_amax(self, dim=None, keepdim=False):
-  return jnp.amax(self, axis=dim, keepdims=keepdim)
+  return _with_reduction_scalar(jnp.amax, self, dim, keepdim)


def _with_reduction_scalar(jax_func, self, dim, keepdim):
  expanded = False
  if self.ndim == 0:
    # For self of rank 0: torch.any(x, 0) and torch.any(x, -1) work,
    # while torch.any(x, 1) throws out of bounds, so its behavior
    # matches that of a jnp array of rank 1.
    expanded = True
    self = jnp.expand_dims(self, 0)
  res = jax_func(self, axis=dim, keepdims=keepdim)
  if expanded:
    res = res.squeeze()
  return res


# aten.any
@op(torch.ops.aten.any)
def _aten_any(self, dim=None, keepdim=False):
-  return jnp.any(self, axis=dim, keepdims=keepdim)
+  return _with_reduction_scalar(jnp.any, self, dim, keepdim)


# aten.arange
@@ -1246,7 +1264,8 @@ def _aten_arange(start,
# aten.argmax
@op(torch.ops.aten.argmax)
def _aten_argmax(self, dim=None, keepdim=False):
-  return jnp.argmax(self, axis=dim, keepdims=keepdim)
+  return _with_reduction_scalar(
+      jnp.argmax, self, dim, keepdim)


# aten.as_strided
@@ -1751,4 +1770,12 @@ def _aten_local_scalar_dense(x):

@op(torch.ops.aten.tensor_split.sections)
def _aten_tensor_split(ary, indices_or_sections, axis=0):
-  return jnp.array_split(ary, indices_or_sections, axis)
+  return jnp.array_split(ary, indices_or_sections, axis)

+@op(torch.ops.aten.outer)
+def _aten_outer(a, b):
+  return jnp.outer(a, b)

+@op(torch.ops.aten.allclose)
+def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+  return jnp.allclose(input, other, rtol, atol, equal_nan)
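The new `_with_reduction_scalar` helper above exists because PyTorch and jax.numpy disagree on rank-0 reductions: torch accepts `dim=0` or `dim=-1` on a 0-d tensor (treating it like a one-element, 1-d tensor), while jax.numpy rejects any explicit axis on a 0-d array. A short standalone sketch of the difference and of the expand/squeeze workaround the helper applies:

import torch
import jax.numpy as jnp

# PyTorch accepts dim=0 on a 0-d tensor ...
print(torch.any(torch.tensor(True), dim=0))   # tensor(True)

# ... jax.numpy does not: axis 0 is out of bounds for a rank-0 array.
try:
  jnp.any(jnp.asarray(True), axis=0)
except ValueError as e:
  print("jnp raises:", e)

# The workaround: temporarily add a length-1 axis, reduce, then squeeze.
x = jnp.asarray(True)
res = jnp.any(jnp.expand_dims(x, 0), axis=0).squeeze()
print(res)   # True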
26 changes: 26 additions & 0 deletions experimental/torch_xla2/torch_xla2/functions.py
@@ -92,6 +92,32 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
  # TODO: handle torch.Size
  return jnp.full(size, fill_value, dtype=dtype)

@register_function(torch.allclose)
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
  return jnp.allclose(input, other, rtol, atol, equal_nan)

@register_function(torch.angle)
def _torch_angle(input):
  return jnp.angle(input)


@register_function(torch.argsort)
def _torch_argsort(input, dim=-1, descending=False, stable=False):
  expanded = False
  if input.ndim == 0:
    # For input of rank 0: torch.argsort(x, 0) and torch.argsort(x, -1) work,
    # while torch.argsort(x, 1) throws out of bounds, so its behavior
    # matches that of a jnp array of rank 1.
    expanded = True
    input = jnp.expand_dims(input, 0)
  res = jnp.argsort(input, axis=dim, descending=descending,
                    stable=stable)
  if expanded:
    res = res.squeeze()
  return res



class XLAFunctionMode(torch.overrides.TorchFunctionMode):
"""Context manager that dispatches torch function calls to JAX."""
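For context on the three `register_function` additions above: `XLAFunctionMode` is a `torch.overrides.TorchFunctionMode`, so while it is active, calls to registered torch-level functions such as `torch.allclose`, `torch.angle`, and `torch.argsort` are rerouted to the jax.numpy implementations. The sketch below shows the same dispatch pattern with stock PyTorch only; the registry and mode names are illustrative, not the module's actual internals:

import torch
import jax.numpy as jnp

_registry = {}  # torch function -> jax-backed implementation (illustrative)

def register(torch_fn):
  def wrap(impl):
    _registry[torch_fn] = impl
    return impl
  return wrap

@register(torch.angle)
def _angle(input):
  return jnp.angle(input)

class JaxFunctionMode(torch.overrides.TorchFunctionMode):
  """Reroutes registered torch functions to their jax.numpy versions."""
  def __torch_function__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if func in _registry:
      # Convert torch.Tensor args to NumPy so jax.numpy can consume them.
      args = [a.numpy() if isinstance(a, torch.Tensor) else a for a in args]
      return _registry[func](*args, **kwargs)
    return func(*args, **kwargs)

with JaxFunctionMode():
  print(torch.angle(torch.tensor([1.0 + 1.0j])))  # jax Array, ~[0.7854]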
7 changes: 7 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -0,0 +1,7 @@
import torch



torch_ops_override = {
    torch.allclose: torch.ops.aten.allclose
}
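The new `jtorch.py` only declares a mapping from a torch-level function to its aten op; no consumer of `torch_ops_override` appears in this commit. One plausible (hypothetical, not from the diff) use is to reuse an existing aten lowering for the torch-level entry point, roughly:

# Hypothetical consumer; `register_function` and an aten-lowering table are
# assumed to exist as suggested by functions.py and _ops.py above.
from torch_xla2.ops.jtorch import torch_ops_override

def install_overrides(register_function, aten_lowerings):
  # Register each torch-level function to reuse its aten lowering.
  for torch_fn, aten_op in torch_ops_override.items():
    register_function(torch_fn)(aten_lowerings[aten_op])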
