diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 650c1b620cf1..3cc14a5490c2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2769,14 +2769,13 @@ def _conj_transpose_rule(t, x, *, input_dtype): abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs') mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs)) -def _abs_jvp_rule(g, ans, x): +def _abs_jvp_rule(g, x): if _iscomplex(x): - return _maybe_real(mul(g, div(_maybe_conj(x), - _replace_zero(convert_element_type(ans, _dtype(x)))))) + a = atan2(imag(x), real(x)) + return _maybe_real(mul(g, complex(cos(a), neg(sin(a))))) else: return select(ge(x, _zero(x)), g, neg(g)) -ad.defjvp2(abs_p, _abs_jvp_rule) -_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x +ad.defjvp(abs_p, _abs_jvp_rule) _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 62b0fc994e60..ec3f69261337 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6285,6 +6285,30 @@ def testGradLdexp(self, n, dtype): x = rng((), dtype) check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + @parameterized.parameters( + [ + (*parts, dtype) + for parts in [ + ((0.0, 0.0), (1.0, 0.0)), + ((np.inf, 0.0), (1.0, 0.0)), + ((-np.inf, 0.0), (-1.0, 0.0)), + ((0.0, np.inf), (0.0, -1.0)), + ((0.0, -np.inf), (0.0, 1.0)), + ((np.inf, np.inf), (0.5 * np.sqrt(2), -0.5 * np.sqrt(2))), + ((np.inf, -np.inf), (0.5 * np.sqrt(2), 0.5 * np.sqrt(2))), + ((-np.inf, np.inf), (-0.5 * np.sqrt(2), -0.5 * np.sqrt(2))), + ((-np.inf, -np.inf), (-0.5 * np.sqrt(2), 0.5 * np.sqrt(2))), + ] + for dtype in complex_dtypes + ] + ) + def testComplexAbsGradInf(self, input_parts, grad_parts, dtype): + # https://github.com/jax-ml/jax/issues/25681 + x = jax.lax.complex(*input_parts).astype(dtype) + expected = jax.lax.complex(*grad_parts).astype(dtype) + g = jax.grad(jnp.abs)(x) + self.assertAllClose(g, expected) + class NumpySignaturesTest(jtu.JaxTestCase):