From 35c2d1f1fa51fb530bc46882695e0a8fb3204cfb Mon Sep 17 00:00:00 2001 From: Quentin Berthet Date: Thu, 16 Nov 2023 06:20:45 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 583021137 --- jaxopt/_src/perturbations.py | 25 ++++++++++++++++-- tests/perturbations_test.py | 50 ++++++++++++++++++++++++++---------- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/jaxopt/_src/perturbations.py b/jaxopt/_src/perturbations.py index f5ba1780..021baee4 100644 --- a/jaxopt/_src/perturbations.py +++ b/jaxopt/_src/perturbations.py @@ -202,7 +202,8 @@ def pert_jvp(tangent, _, inputs, rng): def make_perturbed_fun(fun: Callable[[jax.Array], float], num_samples: int = 1000, sigma: float = 0.1, - noise=Gumbel()): + noise=Gumbel(), + control_variate: bool = False): """Transforms a function into a differentiable version with perturbations. Args: @@ -214,6 +215,8 @@ def make_perturbed_fun(fun: Callable[[jax.Array], float], noise: a distribution object that must implement a sample function and a log-pdf of the desired distribution, similar to the examples above. Default is Gumbel distribution. + control_variate : Boolean indicating whether a control variate is used in + the Monte-Carlo estimate of the Jacobian. Returns: A function with the same signature (and an rng) that can be differentiated. @@ -264,6 +267,24 @@ def pert_jvp(tangent, _, inputs, rng): tangent_out = jnp.squeeze(jnp.reshape(tangent_flat, output_pert.shape[1:])) return tangent_out - forward_pert.defjvps(pert_jvp, None) + def pert_jvp_control_variate(tangent, _, inputs, rng): + samples = noise.sample(seed=rng, + sample_shape=(num_samples,) + inputs.shape) + output_pert = jax.vmap(fun)(inputs + sigma * samples)[..., jnp.newaxis] + output = fun(inputs) + # noise.log_prob corresponds to -nu in the paper. + nabla_z_flat = -jax.vmap(jax.grad(noise.log_prob))(samples.reshape([-1])) + tangent_flat = 1.0 / (num_samples * sigma) * jnp.einsum( + 'nd,ne,e->d', + jnp.reshape(output_pert - output, (num_samples, -1)), + jnp.reshape(nabla_z_flat, (num_samples, -1)), + jnp.reshape(tangent, (-1,))) + tangent_out = jnp.squeeze(jnp.reshape(tangent_flat, output_pert.shape[1:])) + return tangent_out + + if control_variate: + forward_pert.defjvps(pert_jvp_control_variate, None) + else: + forward_pert.defjvps(pert_jvp, None) return forward_pert diff --git a/tests/perturbations_test.py b/tests/perturbations_test.py index 663da2c6..5375ec34 100644 --- a/tests/perturbations_test.py +++ b/tests/perturbations_test.py @@ -22,6 +22,7 @@ from jaxopt import perturbations from jaxopt._src import test_util from jaxopt import loss +from absl.testing import parameterized def one_hot_argmax(inputs: jnp.ndarray) -> jnp.ndarray: @@ -171,7 +172,7 @@ def square_loss_soft(theta): """ Test ensuring that for small sigma, control variate indeed leads - to a Jacobian that is closter to that of the softmax. + to a Jacobian that is closer to that of the softmax. 
""" sigma = 0.01 @@ -200,7 +201,7 @@ def square_loss_soft(theta): jp_rv = jac_pert_argmax_fun_rv(theta_matrix[0], rng) # distance of pert argmax jacobians from softmax jacobian - jac_dist= jnp.linalg.norm((js - jp) ** 2) + jac_dist = jnp.linalg.norm((js - jp) ** 2) jac_dist_rv = jnp.linalg.norm((js - jp_rv) ** 2) self.assertLessEqual(jac_dist_rv, jac_dist) @@ -506,8 +507,10 @@ def setUp(self): [[-0.6, 1.9, -0.2, 1.7, -1.0], [-0.6, 1.0, -0.2, 1.8, -1.0]], dtype=jnp.float32, ) - - def test_scalar_small_sigma(self): + @parameterized.product( + control_variate=[True, False], + ) + def test_scalar_small_sigma(self, control_variate): """Checks that the perturbed scalar is close to the max for small sigma.""" pert_scalar_small_sigma_fun = jax.jit( perturbations.make_perturbed_fun( @@ -515,6 +518,7 @@ def test_scalar_small_sigma(self): num_samples=self.num_samples, sigma=1e-7, noise=perturbations.Gumbel(), + control_variate=control_variate ) ) rngs_batch = jax.random.split(self.rng, 2) @@ -526,7 +530,10 @@ def test_scalar_small_sigma(self): pert_scalar_small_sigma, jnp.array([5.2, 4.4]), atol=1e-6 ) - def test_grads(self): + @parameterized.product( + control_variate=[True, False], + ) + def test_grads(self, control_variate): """Tests that the gradients have the correct shape.""" pert_scalar_fun = jax.jit( perturbations.make_perturbed_fun( @@ -534,6 +541,7 @@ def test_grads(self): num_samples=self.num_samples, sigma=self.sigma, noise=perturbations.Gumbel(), + control_variate=control_variate ) ) @@ -541,7 +549,10 @@ def test_grads(self): self.assertArraysEqual(grad_pert.shape, self.theta_batch.shape) - def test_noise_iid(self): + @parameterized.product( + control_variate=[True, False], + ) + def test_noise_iid(self, control_variate): """Checks that different elements of the batch have different noises.""" pert_scalar_fun = jax.jit( perturbations.make_perturbed_fun( @@ -549,20 +560,25 @@ def test_noise_iid(self): num_samples=self.num_samples, sigma=self.sigma, noise=perturbations.Gumbel(), + control_variate=control_variate ) ) theta_batch_repeat = jnp.array([[-0.6, 1.9, -0.2, 1.1, -1.0], [-0.6, 1.9, -0.2, 1.1, -1.0]], dtype=jnp.float32) rngs_batch = jax.random.split(self.rng, 2) - pert_scalar_repeat = jax.vmap(pert_scalar_fun)(theta_batch_repeat, + pert_scalar_repeat = jax.vmap(pert_scalar_fun)(theta_batch_repeat, rngs_batch) - self.assertArraysAllClose(pert_scalar_repeat[0], pert_scalar_repeat[1], + self.assertArraysAllClose(pert_scalar_repeat[0], + pert_scalar_repeat[1], atol=2e-2) delta_noise = pert_scalar_repeat[0] - pert_scalar_repeat[1] self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0) - def test_distrax(self): + @parameterized.product( + control_variate=[True, False], + ) + def test_distrax(self, control_variate): """Checks that the function is compatible with distrax distributions.""" try: import distrax @@ -583,12 +599,17 @@ def test_distrax(self): fun=scalar_function, num_samples=self.num_samples, sigma=self.sigma, - noise=distrax.Normal(loc=0., scale=1.))) + noise=distrax.Normal(loc=0., scale=1.), + control_variate=control_variate)) + dist_scalar = jax.vmap(dist_scalar_fun)(theta_batch, rngs_batch) self.assertArraysAllClose(pert_scalar, dist_scalar, atol=1e-6) - def test_grad_finite_diff(self): + @parameterized.product( + control_variate=[True, False], + ) + def test_grad_finite_diff(self, control_variate): theta = jnp.array([-0.8, 0.6, 1.2, -1.0, -0.7, 0.6, 1.1, -1.0, 0.4]) # High value of num_samples for this specific test. 
Not required in normal # usecases, as in learning tasks. @@ -596,7 +617,8 @@ def test_grad_finite_diff(self): fun=scalar_function, num_samples=100_000, sigma=self.sigma, - noise=perturbations.Normal())) + noise=perturbations.Normal(), + control_variate=control_variate)) gradient_pert = jax.grad(pert_scalar_fun)(theta, self.rng) eps = 1e-2 @@ -606,13 +628,13 @@ def test_grad_finite_diff(self): pert_scalar_fun(theta - eps * h, rngs[0])) / (2 * eps) delta_lin = jnp.sum(gradient_pert * h) - self.assertArraysAllClose(delta_num, delta_lin, rtol=3e-2) + self.assertArraysAllClose(delta_num, delta_lin, rtol=2e-1) def compare_grads(self): """Checks composition of gradients with only one sample.""" pert_ranks_fun = jax.jit(perturbations.make_perturbed_argmax( argmax_fun=ranks, - num_sample=1, + num_samples=1, sigma=self.sigma)) pert_scalar_fun = jax.jit(perturbations.make_perturbed_fun( fun=scalar_function,
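
Usage note (not part of the patch itself): a minimal sketch of how the new `control_variate` flag added above is meant to be used, assuming the patched `jaxopt.perturbations` API. The flag only swaps the custom JVP rule registered on the perturbed function; the forward value is unchanged, and the control-variate JVP subtracts `fun(inputs)` as a baseline in the Monte-Carlo Jacobian estimate.

import jax
import jax.numpy as jnp
from jaxopt import perturbations

def max_value(theta):
  # Scalar function with piecewise-constant gradient that we want to smooth.
  return jnp.max(theta)

# Differentiable, perturbed version of max_value. control_variate=True enables
# the baseline-corrected JVP introduced in this patch.
pert_max = perturbations.make_perturbed_fun(
    fun=max_value,
    num_samples=10_000,
    sigma=0.1,
    noise=perturbations.Gumbel(),
    control_variate=True,
)

theta = jnp.array([-0.6, 1.9, -0.2, 1.7, -1.0])
rng = jax.random.PRNGKey(0)
value = pert_max(theta, rng)           # smoothed maximum, same with or without the flag
grad = jax.grad(pert_max)(theta, rng)  # Monte-Carlo gradient estimate

As exercised by the Jacobian-distance test above, the control variate is most useful at small sigma, where subtracting the unperturbed value removes most of the variance of the gradient estimate.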