
Commit

No public description
PiperOrigin-RevId: 583021137
q-berthet authored and JAXopt authors committed Nov 16, 2023
1 parent 5b9b62c commit 35c2d1f
Showing 2 changed files with 59 additions and 16 deletions.
25 changes: 23 additions & 2 deletions jaxopt/_src/perturbations.py
@@ -202,7 +202,8 @@ def pert_jvp(tangent, _, inputs, rng):
def make_perturbed_fun(fun: Callable[[jax.Array], float],
num_samples: int = 1000,
sigma: float = 0.1,
noise=Gumbel()):
noise=Gumbel(),
control_variate: bool = False):
"""Transforms a function into a differentiable version with perturbations.
Args:
@@ -214,6 +215,8 @@ def make_perturbed_fun(fun: Callable[[jax.Array], float],
noise: a distribution object that must implement a sample function and a
log-pdf of the desired distribution, similar to the examples above.
Default is Gumbel distribution.
control_variate : Boolean indicating whether a control variate is used in
the Monte-Carlo estimate of the Jacobian.
Returns:
A function with the same signature (and an rng) that can be differentiated.
@@ -264,6 +267,24 @@ def pert_jvp(tangent, _, inputs, rng):
tangent_out = jnp.squeeze(jnp.reshape(tangent_flat, output_pert.shape[1:]))
return tangent_out

forward_pert.defjvps(pert_jvp, None)
def pert_jvp_control_variate(tangent, _, inputs, rng):
samples = noise.sample(seed=rng,
sample_shape=(num_samples,) + inputs.shape)
output_pert = jax.vmap(fun)(inputs + sigma * samples)[..., jnp.newaxis]
output = fun(inputs)
# noise.log_prob corresponds to -nu in the paper.
nabla_z_flat = -jax.vmap(jax.grad(noise.log_prob))(samples.reshape([-1]))
tangent_flat = 1.0 / (num_samples * sigma) * jnp.einsum(
'nd,ne,e->d',
jnp.reshape(output_pert - output, (num_samples, -1)),
jnp.reshape(nabla_z_flat, (num_samples, -1)),
jnp.reshape(tangent, (-1,)))
tangent_out = jnp.squeeze(jnp.reshape(tangent_flat, output_pert.shape[1:]))
return tangent_out

if control_variate:
forward_pert.defjvps(pert_jvp_control_variate, None)
else:
forward_pert.defjvps(pert_jvp, None)

return forward_pert
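
For reference, a minimal usage sketch of the new flag. It assumes only what the diff above establishes (make_perturbed_fun, Gumbel and the control_variate argument living in jaxopt.perturbations); the toy ranking_score function, the input values and the hyper-parameters are illustrative, not part of the commit.

import jax
import jax.numpy as jnp
from jaxopt import perturbations


def ranking_score(inputs: jnp.ndarray) -> float:
  # Piecewise-constant toy objective: inner product of the inputs with their ranks.
  ranks = jnp.argsort(jnp.argsort(inputs)).astype(jnp.float32)
  return jnp.dot(ranks, inputs)


pert_fun = jax.jit(perturbations.make_perturbed_fun(
    fun=ranking_score,
    num_samples=10_000,
    sigma=0.1,
    noise=perturbations.Gumbel(),
    control_variate=True))  # flag introduced by this commit

theta = jnp.array([-0.6, 1.9, -0.2, 1.7, -1.0])
rng = jax.random.PRNGKey(0)
value = pert_fun(theta, rng)           # smoothed value of ranking_score
grad = jax.grad(pert_fun)(theta, rng)  # Monte-Carlo gradient, with the control variate

Since the commit only swaps which custom JVP is registered, the flag changes how the Jacobian is estimated; the forward value remains the same Monte-Carlo average either way.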
50 changes: 36 additions & 14 deletions tests/perturbations_test.py
@@ -22,6 +22,7 @@
from jaxopt import perturbations
from jaxopt._src import test_util
from jaxopt import loss
from absl.testing import parameterized


def one_hot_argmax(inputs: jnp.ndarray) -> jnp.ndarray:
@@ -171,7 +172,7 @@ def square_loss_soft(theta):

"""
Test ensuring that for small sigma, control variate indeed leads
to a Jacobian that is closter to that of the softmax.
to a Jacobian that is closer to that of the softmax.
"""
sigma = 0.01

@@ -200,7 +201,7 @@ def square_loss_soft(theta):
jp_rv = jac_pert_argmax_fun_rv(theta_matrix[0], rng)

# distance of pert argmax jacobians from softmax jacobian
jac_dist= jnp.linalg.norm((js - jp) ** 2)
jac_dist = jnp.linalg.norm((js - jp) ** 2)
jac_dist_rv = jnp.linalg.norm((js - jp_rv) ** 2)
self.assertLessEqual(jac_dist_rv, jac_dist)
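
The reasoning behind this check is the standard control-variate argument: the score nabla_z(-log p(z)) has zero mean under the noise distribution, so subtracting the constant baseline fun(inputs) from fun(inputs + sigma * z) leaves the estimator unbiased while removing the dominant term whose contribution blows up once divided by sigma. The standalone sketch below illustrates that effect with plain JAX; it does not use jaxopt, and f, gumbel_score and grad_estimate are illustrative names rather than library API.

import jax
import jax.numpy as jnp


def f(x):
  return jnp.max(x)  # simple non-smooth scalar objective


def gumbel_score(z):
  # -(d/dz) log p(z) for a standard Gumbel, where log p(z) = -z - exp(-z).
  return 1.0 - jnp.exp(-z)


def grad_estimate(theta, key, sigma=0.01, num_samples=1_000, baseline=False):
  # Score-function estimate of d/dtheta E[f(theta + sigma * Z)], Z ~ Gumbel.
  z = jax.random.gumbel(key, (num_samples,) + theta.shape)
  vals = jax.vmap(f)(theta + sigma * z)[:, None]          # (num_samples, 1)
  if baseline:
    vals = vals - f(theta)                                 # control variate
  return jnp.mean(vals * gumbel_score(z), axis=0) / sigma  # (dim,)


theta = jnp.array([-0.3, 1.2, 0.5])
keys = jax.random.split(jax.random.PRNGKey(0), 200)
plain = jax.vmap(lambda k: grad_estimate(theta, k, baseline=False))(keys)
with_cv = jax.vmap(lambda k: grad_estimate(theta, k, baseline=True))(keys)
# Both estimators target the same gradient, but for small sigma the
# per-coordinate variance is far smaller with the baseline.
print(plain.mean(axis=0), with_cv.mean(axis=0))
print(plain.var(axis=0), with_cv.var(axis=0))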

@@ -506,15 +507,18 @@ def setUp(self):
[[-0.6, 1.9, -0.2, 1.7, -1.0], [-0.6, 1.0, -0.2, 1.8, -1.0]],
dtype=jnp.float32,
)

def test_scalar_small_sigma(self):
@parameterized.product(
control_variate=[True, False],
)
def test_scalar_small_sigma(self, control_variate):
"""Checks that the perturbed scalar is close to the max for small sigma."""
pert_scalar_small_sigma_fun = jax.jit(
perturbations.make_perturbed_fun(
fun=scalar_function,
num_samples=self.num_samples,
sigma=1e-7,
noise=perturbations.Gumbel(),
control_variate=control_variate
)
)
rngs_batch = jax.random.split(self.rng, 2)
@@ -526,43 +530,55 @@ def test_scalar_small_sigma(self):
pert_scalar_small_sigma, jnp.array([5.2, 4.4]), atol=1e-6
)

def test_grads(self):
@parameterized.product(
control_variate=[True, False],
)
def test_grads(self, control_variate):
"""Tests that the gradients have the correct shape."""
pert_scalar_fun = jax.jit(
perturbations.make_perturbed_fun(
fun=scalar_function,
num_samples=self.num_samples,
sigma=self.sigma,
noise=perturbations.Gumbel(),
control_variate=control_variate
)
)

grad_pert = jax.grad(pert_scalar_fun)(self.theta_batch, self.rng)

self.assertArraysEqual(grad_pert.shape, self.theta_batch.shape)

def test_noise_iid(self):
@parameterized.product(
control_variate=[True, False],
)
def test_noise_iid(self, control_variate):
"""Checks that different elements of the batch have different noises."""
pert_scalar_fun = jax.jit(
perturbations.make_perturbed_fun(
fun=scalar_function,
num_samples=self.num_samples,
sigma=self.sigma,
noise=perturbations.Gumbel(),
control_variate=control_variate
)
)
theta_batch_repeat = jnp.array([[-0.6, 1.9, -0.2, 1.1, -1.0],
[-0.6, 1.9, -0.2, 1.1, -1.0]],
dtype=jnp.float32)
rngs_batch = jax.random.split(self.rng, 2)
pert_scalar_repeat = jax.vmap(pert_scalar_fun)(theta_batch_repeat,
pert_scalar_repeat = jax.vmap(pert_scalar_fun)(theta_batch_repeat,
rngs_batch)
self.assertArraysAllClose(pert_scalar_repeat[0], pert_scalar_repeat[1],
self.assertArraysAllClose(pert_scalar_repeat[0],
pert_scalar_repeat[1],
atol=2e-2)
delta_noise = pert_scalar_repeat[0] - pert_scalar_repeat[1]
self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0)

def test_distrax(self):
@parameterized.product(
control_variate=[True, False],
)
def test_distrax(self, control_variate):
"""Checks that the function is compatible with distrax distributions."""
try:
import distrax
@@ -583,20 +599,26 @@ def test_distrax(self):
fun=scalar_function,
num_samples=self.num_samples,
sigma=self.sigma,
noise=distrax.Normal(loc=0., scale=1.)))
noise=distrax.Normal(loc=0., scale=1.),
control_variate=control_variate))

dist_scalar = jax.vmap(dist_scalar_fun)(theta_batch, rngs_batch)

self.assertArraysAllClose(pert_scalar, dist_scalar, atol=1e-6)

def test_grad_finite_diff(self):
@parameterized.product(
control_variate=[True, False],
)
def test_grad_finite_diff(self, control_variate):
theta = jnp.array([-0.8, 0.6, 1.2, -1.0, -0.7, 0.6, 1.1, -1.0, 0.4])
# High value of num_samples for this specific test. Not required in normal
# usecases, as in learning tasks.
pert_scalar_fun = jax.jit(perturbations.make_perturbed_fun(
fun=scalar_function,
num_samples=100_000,
sigma=self.sigma,
noise=perturbations.Normal()))
noise=perturbations.Normal(),
control_variate=control_variate))

gradient_pert = jax.grad(pert_scalar_fun)(theta, self.rng)
eps = 1e-2
@@ -606,13 +628,13 @@ def test_grad_finite_diff(self):
pert_scalar_fun(theta - eps * h, rngs[0])) / (2 * eps)
delta_lin = jnp.sum(gradient_pert * h)

self.assertArraysAllClose(delta_num, delta_lin, rtol=3e-2)
self.assertArraysAllClose(delta_num, delta_lin, rtol=2e-1)

def compare_grads(self):
"""Checks composition of gradients with only one sample."""
pert_ranks_fun = jax.jit(perturbations.make_perturbed_argmax(
argmax_fun=ranks,
num_sample=1,
num_samples=1,
sigma=self.sigma))
pert_scalar_fun = jax.jit(perturbations.make_perturbed_fun(
fun=scalar_function,
