From 5dbdb8db96a49fefc1de240bc505ffc5ff73a18b Mon Sep 17 00:00:00 2001
From: Simon Teo
Date: Wed, 2 Oct 2024 02:02:41 +0800
Subject: [PATCH] Fix lgamma, mvlgamma, multinomial, and nanmedian (#7513)
 (#8095)

Co-authored-by: Simon Teo
---
 experimental/torch_xla2/test/test_ops.py        |  5 +--
 .../torch_xla2/torch_xla2/ops/jaten.py          | 31 +++++++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 27782be7982..08965d085db 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -37,7 +37,6 @@
     "igammac",
     "index_reduce",
     "kthvalue",
-    "lgamma",
     "linalg.cholesky",
     "linalg.cholesky_ex",
     "linalg.det",
@@ -69,9 +68,6 @@
     "lu_unpack",
     "masked.median",
     "max_pool2d_with_indices_backward",
-    "multinomial",
-    "mvlgamma",
-    "nanmedian",
     "new_empty_strided",
     "nextafter",
     "nn.functional.adaptive_avg_pool3d",
@@ -169,6 +165,7 @@
     'rand',
     'rand_like',
     'uniform',
+    'multinomial',
     # Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html
     'nn.functional.feature_alpha_dropout',
 }
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 5f3a452259d..9ea94bc46c4 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2232,6 +2232,13 @@ def _aten_hypot(input, other):
 def _aten_igamma(input, other):
   return jax.scipy.special.gammainc(input, other)
 
+@op(torch.ops.aten.lgamma)
+def _aten_lgamma(input, *, out=None):
+  return jax.scipy.special.gammaln(input).astype(jnp.float32)
+
+@op(torch.ops.aten.mvlgamma)
+def _aten_mvlgamma(input, p, *, out=None):
+  return jax.scipy.special.multigammaln(input, p)
 
 @op(torch.ops.aten.linalg_eig)
 def _aten_linalg_eig(A):
@@ -3935,6 +3942,19 @@ def f(k, carry):
   return vectorized(self, n.astype(jnp.int64))
 
 
+@op(torch.ops.aten.multinomial, needs_env=True)
+def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
+  assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
+  assert jnp.all(input >= 0), "inputs must be non-negative"
+  key = env.get_and_rotate_prng_key(generator)
+  if input.ndim == 1:
+    assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
+    return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
+  else:
+    assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
+    return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])
+
+
 @op(torch.ops.aten.narrow)
 @op(torch.ops.aten.narrow_copy)
 def _aten_narrow(input, dim, start, length):
@@ -4047,6 +4067,17 @@ def _aten_median(self, dim=None, keepdim=False):
     index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
     return output, index
 
+
+@op(torch.ops.aten.nanmedian)
+def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
+  output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
+  if dim is None:
+    return output
+  else:
+    index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
+    return output, index
+
+
 def _get_median_index(x, axis=None, keepdims=False):
   sorted_arg = jnp.argsort(x, axis=axis)
   n = x.shape[axis] if axis is not None else x.size
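
Reviewer note, not part of the patch: a minimal standalone sanity check for the
JAX primitives these lowerings lean on. It calls the JAX functions directly
rather than going through torch_xla2's op dispatch, so the script below is an
illustrative sketch only (it assumes jax and torch are importable; the file
name sanity_check.py is hypothetical).

# sanity_check.py (hypothetical) -- compare the JAX kernels against torch references
import jax
import jax.numpy as jnp
import torch

x = jnp.array([0.5, 1.5, 2.5], dtype=jnp.float32)
t = torch.tensor([0.5, 1.5, 2.5])

# lgamma: jax.scipy.special.gammaln should agree with torch.lgamma.
print(jax.scipy.special.gammaln(x), torch.lgamma(t))

# mvlgamma: multigammaln(a, p) requires a > (p - 1) / 2, hence the +1 shift.
print(jax.scipy.special.multigammaln(x + 1.0, 2), torch.mvlgamma(t + 1.0, p=2))

# multinomial: jax.random.choice accepts unnormalized weights, like torch.multinomial.
key = jax.random.PRNGKey(0)
w = jnp.array([1.0, 0.0, 9.0])
print(jax.random.choice(key, w.shape[-1], (5,), replace=True, p=w))

# nanmedian: nanquantile with q=0.5 and method='lower' matches torch's
# convention of taking the lower of the two middle values.
y = jnp.array([1.0, jnp.nan, 3.0, 2.0])
print(jnp.nanquantile(y, 0.5, method='lower'))                        # 2.0
print(torch.nanmedian(torch.tensor([1.0, float('nan'), 3.0, 2.0])))   # 2.0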