diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 5558b373b97..28d0f29f0c1 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -15,7 +15,6 @@ "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", "cat", - "cholesky", "cholesky_solve", "diagonal_copy", "geqrf", @@ -23,8 +22,6 @@ "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. "index_reduce", "kthvalue", - "linalg.cholesky", - "linalg.cholesky_ex", "linalg.det", "linalg.ldl_solve", "linalg.lu_solve", @@ -49,13 +46,7 @@ "nn.functional.max_pool2d", "nn.functional.max_pool3d", "nn.functional.multi_head_attention_forward", - "nn.functional.multilabel_margin_loss", - "nn.functional.pairwise_distance", - "nn.functional.poisson_nll_loss", - "nn.functional.rrelu", "nn.functional.upsample_nearest", - "nonzero", - "nonzero_static", "normal", "ormqr", "pca_lowrank", @@ -67,7 +58,6 @@ "special.zeta", "unfold_copy", "unfold", - "randint", } not_support_ops_list = { @@ -77,6 +67,7 @@ "ceil", # only failed with python 3.9 "trunc", # only failed with python 3.9 "to_sparse", # We are not supporting sparse tensors yet. + "nn.functional.rrelu", # pure torch result match torch_xla2 test result, only OpInfo mismatch: https://gist.github.com/ManfeiBai/1a449b15f4e946bfcaa3e5ef86da20f4 } # These inputs are themselves views @@ -106,6 +97,7 @@ 'cauchy', 'exponential', 'log_normal', + 'randint', } atol_dict = {"linalg.eig": (2e0, 3e0), diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index c72459cbdb3..8a82052c669 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -236,6 +236,46 @@ def _aten_index_select(x, dim, index): return jnp.take(x, index, dim) +@op(torch.ops.aten.cholesky) +def _aten_cholesky(input, upper=False): + return jax.scipy.linalg.cholesky(input, lower=(not upper)) + + +@op(torch.ops.aten.linalg_cholesky_ex) +def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False): + if check_errors: + raise NotImplementedError( + "check_errors=True is not supported in this JAX implementation. " + "Check for positive definiteness using jnp.linalg.eigvalsh before " + "calling this function." + ) + + L = jax.scipy.linalg.cholesky(input, lower=not upper) + if len(L.shape) >2: + info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) + else: + info = jnp.array(0, dtype=jnp.int32) + return L, info + + +@op(torch.ops.aten.cholesky_solve) +def _aten_cholesky_solve(input, input2, upper=False): + # Ensure input2 is lower triangular for cho_solve + L = input2 if not upper else input2.T + # Use cho_solve to solve the linear system + solution = jax.scipy.linalg.cho_solve((L, True), input) + return solution + + +@op(torch.ops.aten.special_zeta) +def _aten_special_zeta(x, q): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jax.scipy.special.zeta(x, q) + if isinstance(x, int) or isinstance(q, int): + res = res.astype(new_dtype) + return res # jax.scipy.special.zeta(x, q) + + # aten.igammac @op(torch.ops.aten.igammac) def _aten_igammac(input, other): @@ -268,8 +308,13 @@ def _torch_binary_scalar_type(scalar, tensor): @op(torch.ops.aten.searchsorted.Tensor) -def _aten_searchsorted(sorted_sequence, values): - return jnp.searchsorted(sorted_sequence, values) +def _aten_searchsorted(sorted_sequence, values): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jnp.searchsorted(sorted_sequence, values) + if sorted_sequence.dtype == np.dtype(np.int32) or sorted_sequence.dtype == np.dtype(np.int32): + # res = res.astype(new_dtype) + res = res.astype(np.dtype(np.int64)) + return res # jnp.searchsorted(sorted_sequence, values) @op(torch.ops.aten.sub.Tensor) @@ -284,6 +329,21 @@ def _aten_sub(x, y, alpha=1): return x - y*alpha +@op(torch.ops.aten.numpy_T) +def _aten_numpy_T(input): + """ + Jax implementation of torch.numpy_T. + + Args: + input: JAX array. + + Returns: + Transposed JAX array. + """ + return jnp.transpose(input) + + + @op(torch.ops.aten.mm) def _aten_mm(x, y): res = x @ y @@ -2735,9 +2795,24 @@ def _aten_nextafter(input, other, *, out=None): return jnp.nextafter(input, other) +@op(torch.ops.aten.nonzero_static) +def _aten_nonzero_static(input, size, fill_value = -1): + indices = jnp.argwhere(input) + + if size < indices.shape[0]: + indices = indices[:size] + elif size > indices.shape[0]: + padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype) + indices = jnp.concatenate((indices, padding)) + + return indices + + # aten.nonzero @op(torch.ops.aten.nonzero) -def _aten_nonzero(x): +def _aten_nonzero(x, as_tuple=False): + if jnp.ndim(x) == 0 and (as_tuple or x.item()==0): + return torch.empty(0, 0, dtype=torch.int64) if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) res = torch.empty(1, 0, dtype=torch.int64) return jnp.array(res.numpy())