[torch_xla2] add nn.functional.multilabel_margin_loss ... randint (
ManfeiBai authored Nov 15, 2024
1 parent bd4006e commit 91f5c8a
Showing 2 changed files with 80 additions and 13 deletions.
12 changes: 2 additions & 10 deletions experimental/torch_xla2/test/test_ops.py
@@ -15,16 +15,13 @@
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"byte",
"cat",
"cholesky",
"cholesky_solve",
"diagonal_copy",
"geqrf",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
"index_reduce",
"kthvalue",
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
"linalg.ldl_solve",
"linalg.lu_solve",
@@ -49,13 +46,7 @@
"nn.functional.max_pool2d",
"nn.functional.max_pool3d",
"nn.functional.multi_head_attention_forward",
"nn.functional.multilabel_margin_loss",
"nn.functional.pairwise_distance",
"nn.functional.poisson_nll_loss",
"nn.functional.rrelu",
"nn.functional.upsample_nearest",
"nonzero",
"nonzero_static",
"normal",
"ormqr",
"pca_lowrank",
@@ -67,7 +58,6 @@
"special.zeta",
"unfold_copy",
"unfold",
"randint",
}

not_support_ops_list = {
@@ -77,6 +67,7 @@
"ceil", # only failed with python 3.9
"trunc", # only failed with python 3.9
"to_sparse", # We are not supporting sparse tensors yet.
"nn.functional.rrelu", # pure torch result match torch_xla2 test result, only OpInfo mismatch: https://gist.github.com/ManfeiBai/1a449b15f4e946bfcaa3e5ef86da20f4
}

# These inputs are themselves views
@@ -106,6 +97,7 @@
'cauchy',
'exponential',
'log_normal',
'randint',
}

atol_dict = {"linalg.eig": (2e0, 3e0),
81 changes: 78 additions & 3 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -236,6 +236,46 @@ def _aten_index_select(x, dim, index):
  return jnp.take(x, index, dim)


@op(torch.ops.aten.cholesky)
def _aten_cholesky(input, upper=False):
  return jax.scipy.linalg.cholesky(input, lower=(not upper))
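A quick sanity check of the convention mapping (an illustrative sketch, not part of the commit): torch's default upper=False asks for a lower-triangular factor, which corresponds to lower=True in jax.scipy.linalg.cholesky.

import jax.numpy as jnp
from jax.scipy import linalg as jsl

a = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
L = jsl.cholesky(a, lower=True)          # same factor torch.cholesky(a) returns
assert jnp.allclose(L @ L.T, a, atol=1e-6)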


@op(torch.ops.aten.linalg_cholesky_ex)
def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
  if check_errors:
    raise NotImplementedError(
        "check_errors=True is not supported in this JAX implementation. "
        "Check for positive definiteness using jnp.linalg.eigvalsh before "
        "calling this function."
    )

  L = jax.scipy.linalg.cholesky(input, lower=not upper)
  if len(L.shape) > 2:
    info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)
  else:
    info = jnp.array(0, dtype=jnp.int32)
  return L, info
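For context on the info branch (a sketch, assuming jax.scipy.linalg.cholesky batches over leading dimensions, which the code above relies on): torch.linalg.cholesky_ex returns one status flag per matrix, so a batched input of shape (2, 2, 2) needs an info of shape (2,), while a single matrix gets a scalar.

import jax.numpy as jnp
from jax.scipy import linalg as jsl

batch = jnp.stack([4.0 * jnp.eye(2), 9.0 * jnp.eye(2)])  # shape (2, 2, 2)
L = jsl.cholesky(batch, lower=True)
info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)    # one flag per matrix
print(L.shape, info.shape)                               # (2, 2, 2) (2,)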


@op(torch.ops.aten.cholesky_solve)
def _aten_cholesky_solve(input, input2, upper=False):
  # Ensure input2 is lower triangular for cho_solve
  L = input2 if not upper else input2.T
  # Use cho_solve to solve the linear system
  solution = jax.scipy.linalg.cho_solve((L, True), input)
  return solution
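A minimal round trip through the cho_solve pattern used above (illustrative only): factor a positive-definite matrix, then solve and verify.

import jax.numpy as jnp
from jax.scipy import linalg as jsl

a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([[1.0], [2.0]])
L = jsl.cholesky(a, lower=True)
x = jsl.cho_solve((L, True), b)  # True marks L as lower triangular
assert jnp.allclose(a @ x, b, atol=1e-6)

One caveat worth noting: .T on a JAX array reverses every axis, so for batched upper factors the transpose of only the last two dimensions would differ; for the 2-D case shown here the two agree.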


@op(torch.ops.aten.special_zeta)
def _aten_special_zeta(x, q):
  # torch promotes integer inputs to the default float dtype; match that here.
  new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
  res = jax.scipy.special.zeta(x, q)
  if isinstance(x, int) or isinstance(q, int):
    res = res.astype(new_dtype)
  return res
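A spot check against a known value (a sketch, not from the commit): the Hurwitz zeta at (2, 1) equals the Riemann zeta(2), i.e. pi**2 / 6.

import jax.numpy as jnp
from jax.scipy import special

res = special.zeta(2, 1)  # Hurwitz zeta; equals pi**2 / 6 here
assert jnp.isclose(res, jnp.pi ** 2 / 6, atol=1e-5)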


# aten.igammac
@op(torch.ops.aten.igammac)
def _aten_igammac(input, other):
@@ -268,8 +308,13 @@ def _torch_binary_scalar_type(scalar, tensor):


@op(torch.ops.aten.searchsorted.Tensor)
-def _aten_searchsorted(sorted_sequence, values):
-  return jnp.searchsorted(sorted_sequence, values)
+def _aten_searchsorted(sorted_sequence, values):
+  res = jnp.searchsorted(sorted_sequence, values)
+  # torch.searchsorted returns int64 indices; jnp.searchsorted returns the
+  # default integer dtype, so cast integer-sequence results to int64 to match.
+  if sorted_sequence.dtype in (np.dtype(np.int32), np.dtype(np.int64)):
+    res = res.astype(np.dtype(np.int64))
+  return res
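The cast exists because torch.searchsorted always returns int64 indices, while jnp.searchsorted returns JAX's default integer dtype (int32 unless x64 mode is enabled). A small comparison, not part of the commit:

import torch
import jax.numpy as jnp

seq = jnp.array([1, 3, 5, 7], dtype=jnp.int32)
print(jnp.searchsorted(seq, 4).dtype)                           # int32 by default
print(torch.searchsorted(torch.tensor([1, 3, 5, 7]), 4).dtype)  # torch.int64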


@op(torch.ops.aten.sub.Tensor)
@@ -284,6 +329,21 @@ def _aten_sub(x, y, alpha=1):
  return x - y * alpha


@op(torch.ops.aten.numpy_T)
def _aten_numpy_T(input):
  """JAX implementation of torch.numpy_T.

  Args:
    input: JAX array.

  Returns:
    Transposed JAX array.
  """
  return jnp.transpose(input)
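numpy_T backs the Tensor.T accessor, which reverses all dimensions (not just the last two), and jnp.transpose with no axes argument does exactly that. A one-line check, for illustration:

import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
print(jnp.transpose(x).shape)  # (4, 3, 2): all axes reversed, like Tensor.T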



@op(torch.ops.aten.mm)
def _aten_mm(x, y):
  res = x @ y
@@ -2735,9 +2795,24 @@ def _aten_nextafter(input, other, *, out=None):
  return jnp.nextafter(input, other)


@op(torch.ops.aten.nonzero_static)
def _aten_nonzero_static(input, size, fill_value=-1):
  indices = jnp.argwhere(input)

  if size < indices.shape[0]:
    indices = indices[:size]
  elif size > indices.shape[0]:
    padding = jnp.full((size - indices.shape[0], indices.shape[1]),
                       fill_value, dtype=indices.dtype)
    indices = jnp.concatenate((indices, padding))

  return indices
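A worked example of the truncate-or-pad semantics (hypothetical values; the call assumes the op function defined above is in scope): nonzero_static always returns exactly size rows.

import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])  # nonzero entries at indices 1 and 3
print(_aten_nonzero_static(x, size=3))
# [[ 1]
#  [ 3]
#  [-1]]   <- two real indices plus one row of fill_value padding

The fixed size is what makes this op trace-friendly: plain jnp.argwhere has a data-dependent output shape, so it needs concrete inputs, while nonzero_static's output shape is known up front.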


# aten.nonzero
@op(torch.ops.aten.nonzero)
-def _aten_nonzero(x):
+def _aten_nonzero(x, as_tuple=False):
+  if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0):
+    return torch.empty(0, 0, dtype=torch.int64)
+  if jnp.ndim(x) == 0:  # when x is a scalar, torch returns an empty tensor of size (1, 0)
+    res = torch.empty(1, 0, dtype=torch.int64)
+    return jnp.array(res.numpy())
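The new scalar branches mirror torch's corner cases as encoded above: a nonzero 0-d input yields an empty (1, 0) index tensor, and a zero-valued scalar (or as_tuple=True) yields shape (0, 0). A shape-only check on the torch side, for illustration:

import torch

print(torch.nonzero(torch.tensor(5)).shape)  # torch.Size([1, 0])
print(torch.nonzero(torch.tensor(0)).shape)  # torch.Size([0, 0])

The remainder of the function, truncated in this view, presumably handles the general n-d case.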
