Skip to content

Commit

Permalink
Merge pull request #300 from mblondel:ot_stop_gradient
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 470727718
  • Loading branch information
JAXopt authors committed Aug 29, 2022
2 parents b742705 + 46cac2c commit 745ab09
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jaxopt/_src/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,18 @@ def projection_box_section(x: jnp.ndarray,

def _max_l2(x, marginal_b, gamma):
scale = gamma * marginal_b
p = projection_simplex(x / scale)
x_scale = x / scale
p = projection_simplex(x_scale)
# From Danskin's theorem, we do not need to backpropagate
# through projection_simplex.
p = jax.lax.stop_gradient(p)
return jnp.dot(x, p) - 0.5 * scale * jnp.dot(p, p)


def _max_ent(x, marginal_b, gamma):
return gamma * logsumexp(x / gamma) - gamma * jnp.log(marginal_b)


_max_l2_vmap = jax.vmap(_max_l2, in_axes=(1, 0, None))
_max_l2_grad_vmap = jax.vmap(jax.grad(_max_l2), in_axes=(1, 0, None))

Expand Down Expand Up @@ -771,4 +776,4 @@ def kl_projection_birkhoff(sim_matrix: jnp.ndarray,
return kl_projection_transport(sim_matrix=sim_matrix,
marginals=(marginals_a, marginals_b),
make_solver=make_solver,
use_semi_dual=use_semi_dual)
use_semi_dual=use_semi_dual)

0 comments on commit 745ab09

Please sign in to comment.