From 6d690904dc92e16a3041e495421f177575a9f4aa Mon Sep 17 00:00:00 2001 From: Simon Teo Date: Sat, 19 Oct 2024 01:21:40 +0800 Subject: [PATCH] Fix torch digamma (#7383) (#8260) Co-authored-by: Simon Teo --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aabd2a11352..befaae21517 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -18,7 +18,6 @@ "cholesky", "cholesky_solve", "diagonal_copy", - "digamma", "geqrf", "histogram", # hard op: AssertionError: Tensor-likes are not close! "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 05f440cae98..bee6a7c3810 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2394,6 +2394,12 @@ def _aten_hypot(input, other): return jnp.hypot(input, other) +@op(torch.ops.aten.digamma) +def _aten_digamma(input, *, out=None): + res = jax.scipy.special.digamma(input).astype(jnp.float32) + # replace indices where input == 0 with -inf in res + return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) + @op(torch.ops.aten.igamma) def _aten_igamma(input, other): return jax.scipy.special.gammainc(input, other)