Skip to content

Commit

Permalink
Fix torch digamma (#7383) (#8260)
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Teo <[email protected]>
  • Loading branch information
simonteozw and Simon Teo authored Oct 18, 2024
1 parent cec5052 commit 6d69090
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"cholesky",
"cholesky_solve",
"diagonal_copy",
"digamma",
"geqrf",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
Expand Down
6 changes: 6 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,6 +2394,12 @@ def _aten_hypot(input, other):
return jnp.hypot(input, other)


@op(torch.ops.aten.digamma)
def _aten_digamma(input, *, out=None):
res = jax.scipy.special.digamma(input).astype(jnp.float32)
# replace indices where input == 0 with -inf in res
return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res)

@op(torch.ops.aten.igamma)
def _aten_igamma(input, other):
return jax.scipy.special.gammainc(input, other)
Expand Down

0 comments on commit 6d69090

Please sign in to comment.