Skip to content

Commit

Permalink
Update JVP rule for abs.
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Jan 24, 2025
1 parent 1f23253 commit 3bb69bd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
9 changes: 4 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,14 +2769,13 @@ def _conj_transpose_rule(t, x, *, input_dtype):
abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs))

def _abs_jvp_rule(g, ans, x):
def _abs_jvp_rule(g, x):
if _iscomplex(x):
return _maybe_real(mul(g, div(_maybe_conj(x),
_replace_zero(convert_element_type(ans, _dtype(x))))))
a = atan2(imag(x), real(x))
return _maybe_real(mul(g, complex(cos(a), neg(sin(a)))))
else:
return select(ge(x, _zero(x)), g, neg(g))
ad.defjvp2(abs_p, _abs_jvp_rule)
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
ad.defjvp(abs_p, _abs_jvp_rule)
_maybe_real = lambda x: real(x) if _iscomplex(x) else x

sqrt_p = standard_unop(_float | _complex, 'sqrt')
Expand Down
24 changes: 24 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6285,6 +6285,30 @@ def testGradLdexp(self, n, dtype):
x = rng((), dtype)
check_grads(lambda x: jnp.ldexp(x, n), (x,), 1)

@parameterized.parameters(
[
(*parts, dtype)
for parts in [
((0.0, 0.0), (1.0, 0.0)),
((np.inf, 0.0), (1.0, 0.0)),
((-np.inf, 0.0), (-1.0, 0.0)),
((0.0, np.inf), (0.0, -1.0)),
((0.0, -np.inf), (0.0, 1.0)),
((np.inf, np.inf), (0.5 * np.sqrt(2), -0.5 * np.sqrt(2))),
((np.inf, -np.inf), (0.5 * np.sqrt(2), 0.5 * np.sqrt(2))),
((-np.inf, np.inf), (-0.5 * np.sqrt(2), -0.5 * np.sqrt(2))),
((-np.inf, -np.inf), (-0.5 * np.sqrt(2), 0.5 * np.sqrt(2))),
]
for dtype in complex_dtypes
]
)
def testComplexAbsGradInf(self, input_parts, grad_parts, dtype):
# https://github.com/jax-ml/jax/issues/25681
x = jax.lax.complex(*input_parts).astype(dtype)
expected = jax.lax.complex(*grad_parts).astype(dtype)
g = jax.grad(jnp.abs)(x)
self.assertAllClose(g, expected)


class NumpySignaturesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 3bb69bd

Please sign in to comment.