fix step-dependent parameterization
Martin Schubert authored and Martin Schubert committed Aug 15, 2024
1 parent e921f43 commit 4abc138
Showing 4 changed files with 191 additions and 80 deletions.
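
For context, a step-dependent parameterization is one whose `update(leaf, step)` method changes the latent representation as optimization proceeds, for example a filter-and-project scheme whose projection strength `beta` grows over time. The sketch below shows the usage pattern this commit fixes; it closely follows the new tests added here (including the `dataclasses.replace` override of `update`), and the density shape and Adam learning rate are simply the values those tests use.

import dataclasses

import jax.numpy as jnp
import optax
from totypes import types

from invrs_opt.optimizers import wrapped_optax
from invrs_opt.parameterization import filter_project

# Filter-project parameterization whose `beta` is incremented by 1 at every step.
parameterization = filter_project.filter_project(beta=1)
parameterization.update = lambda leaf, step: dataclasses.replace(leaf, beta=leaf.beta + 1)

opt = wrapped_optax.parameterized_wrapped_optax(
    opt=optax.adam(0.01),
    density_parameterization=parameterization,
    penalty=1.0,
)

density = types.Density2DArray(
    array=jnp.full((20, 20), 0.5),
    lower_bound=0,
    upper_bound=1,
    minimum_width=4,
    minimum_spacing=4,
)
state = opt.init(density)

With the fix, each call to `opt.update` advances the parameterization by exactly one step, so after ten steps `beta` has grown from 1 to 11, which is what the new tests assert.
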
145 changes: 84 additions & 61 deletions src/invrs_opt/optimizers/wrapped_optax.py
@@ -178,64 +178,15 @@ def parameterized_wrapped_optax(
if density_parameterization is None:
density_parameterization = pixel.pixel()

def _init_latents(params: PyTree) -> PyTree:
def _leaf_init_latents(leaf: Any) -> Any:
leaf = _clip(leaf)
if not _is_density(leaf):
return leaf
return density_parameterization.from_density(leaf)

return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

def _params_from_latents(params: PyTree) -> PyTree:
def _leaf_params_from_latents(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.to_density(leaf)

return tree_util.tree_map(
_leaf_params_from_latents,
params,
is_leaf=_is_parameterized_density,
)

def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
def _constraint_loss_leaf(
params: parameterization_base.ParameterizedDensity2DArrayBase,
) -> jnp.ndarray:
constraints = density_parameterization.constraints(params)
constraints = tree_util.tree_map(
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
constraints,
)
return jnp.sum(jnp.asarray(constraints))

losses = [0.0] + [
_constraint_loss_leaf(p)
for p in tree_util.tree_leaves(
latent_params, is_leaf=_is_parameterized_density
)
if _is_parameterized_density(p)
]
return penalty * jnp.sum(jnp.asarray(losses))

def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
def _update_leaf(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.update(leaf, step)

return tree_util.tree_map(
_update_leaf,
latent_params,
is_leaf=_is_parameterized_density,
)

def init_fn(params: PyTree) -> WrappedOptaxState:
"""Initializes the optimization state."""
latent_params = _init_latents(params)
params = _params_from_latents(latent_params)
return 0, params, latent_params, opt.init(latent_params)
return (
0, # step
_params_from_latents(latent_params), # params
latent_params, # latent params
opt.init(tree_util.tree_leaves(latent_params)), # opt state
)

def params_fn(state: WrappedOptaxState) -> PyTree:
"""Returns the parameters for the given `state`."""
@@ -252,7 +203,8 @@ def update_fn(
"""Updates the state."""
del value, params

step, _, latent_params, opt_state = state
step, params, latent_params, opt_state = state

_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
(latent_grad,) = vjp_fn(grad)

@@ -271,14 +223,85 @@
lambda a, b: a + b, latent_grad, constraint_loss_grad
)

updates, opt_state = opt.update(
updates=latent_grad, state=opt_state, params=latent_params
updates_leaves, opt_state = opt.update(
updates=tree_util.tree_leaves(latent_grad),
state=opt_state,
params=tree_util.tree_leaves(latent_params),
)
latent_params_leaves = optax.apply_updates(
params=tree_util.tree_leaves(latent_params),
updates=updates_leaves,
)
latent_params = optax.apply_updates(params=latent_params, updates=updates)
latent_params = tree_util.tree_unflatten(
treedef=tree_util.tree_structure(latent_params),
leaves=latent_params_leaves,
)

latent_params = _clip(latent_params)
latent_params = _update_parameterized_densities(latent_params, step)
latent_params = _update_parameterized_densities(latent_params, step + 1)
params = _params_from_latents(latent_params)
return step + 1, params, latent_params, opt_state
return (step + 1, params, latent_params, opt_state)

# -------------------------------------------------------------------------
# Functions related to the density parameterization.
# -------------------------------------------------------------------------

def _init_latents(params: PyTree) -> PyTree:
def _leaf_init_latents(leaf: Any) -> Any:
leaf = _clip(leaf)
if not _is_density(leaf):
return leaf
return density_parameterization.from_density(leaf)

return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

def _params_from_latents(params: PyTree) -> PyTree:
def _leaf_params_from_latents(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.to_density(leaf)

return tree_util.tree_map(
_leaf_params_from_latents,
params,
is_leaf=_is_parameterized_density,
)

def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
def _update_leaf(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.update(leaf, step)

return tree_util.tree_map(
_update_leaf,
latent_params,
is_leaf=_is_parameterized_density,
)

# -------------------------------------------------------------------------
# Functions related to the constraints to be minimized.
# -------------------------------------------------------------------------

def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
def _constraint_loss_leaf(
params: parameterization_base.ParameterizedDensity2DArrayBase,
) -> jnp.ndarray:
constraints = density_parameterization.constraints(params)
constraints = tree_util.tree_map(
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
constraints,
)
return jnp.sum(jnp.asarray(constraints))

losses = [0.0] + [
_constraint_loss_leaf(p)
for p in tree_util.tree_leaves(
latent_params, is_leaf=_is_parameterized_density
)
if _is_parameterized_density(p)
]
return penalty * jnp.sum(jnp.asarray(losses))

return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)

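The substantive changes in this file are twofold: parameterized densities are now updated with `step + 1` rather than `step`, and the wrapped optax optimizer is initialized and updated on the flattened leaves of the latent pytree (via `tree_util.tree_leaves`), with the tree rebuilt afterwards by `tree_util.tree_unflatten`. Operating on leaves means the optax state depends only on the raw arrays, which is presumably what allows the parameterization to rewrite static leaf metadata (such as a step-dependent `beta`) between steps without invalidating that state. Below is a minimal standalone sketch of the flatten/update/unflatten pattern using plain optax and jax; the parameter dict is illustrative.

import jax.numpy as jnp
import optax
from jax import tree_util

params = {"a": jnp.ones((3,)), "b": jnp.zeros((2, 2))}
grads = {"a": jnp.full((3,), 0.1), "b": jnp.full((2, 2), -0.2)}

opt = optax.adam(0.01)

# Initialize the optimizer state from the leaves only, not from the full pytree.
opt_state = opt.init(tree_util.tree_leaves(params))

# Apply the update to the leaves, then rebuild the original tree structure.
updates_leaves, opt_state = opt.update(
    updates=tree_util.tree_leaves(grads),
    state=opt_state,
    params=tree_util.tree_leaves(params),
)
params_leaves = optax.apply_updates(
    params=tree_util.tree_leaves(params),
    updates=updates_leaves,
)
params = tree_util.tree_unflatten(
    treedef=tree_util.tree_structure(params),
    leaves=params_leaves,
)
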
41 changes: 23 additions & 18 deletions tests/optimizers/test_algos.py
@@ -10,6 +10,7 @@
import jax.numpy as jnp
import numpy as onp
import optax
from jax import tree_util
from parameterized import parameterized
from totypes import json_utils, symmetry, types

@@ -76,7 +77,7 @@
),
},
)
PARAMS_WITH_BOUNDED_ARRAY_JAX = jax.tree_util.tree_map(
PARAMS_WITH_BOUNDED_ARRAY_JAX = tree_util.tree_map(
jnp.asarray, PARAMS_WITH_BOUNDED_ARRAY_NUMPY
)
PARAMS_WITH_DENSITY_2D_NUMPY = (
@@ -135,7 +136,7 @@
),
},
)
PARAMS_WITH_DENSITY_2D_JAX = jax.tree_util.tree_map(
PARAMS_WITH_DENSITY_2D_JAX = tree_util.tree_map(
jnp.asarray, PARAMS_WITH_DENSITY_2D_NUMPY
)
PARAMS = [
@@ -150,7 +151,7 @@

def _lists_to_tuple(pytree, max_depth=10):
for _ in range(max_depth):
pytree = jax.tree_util.tree_map(
pytree = tree_util.tree_map(
lambda x: tuple(x) if isinstance(x, list) else x,
pytree,
is_leaf=lambda x: isinstance(x, list),
@@ -170,24 +171,28 @@ class BasicOptimizerTest(unittest.TestCase):
@parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
def test_state_is_serializable(self, params, opt):
state = opt.init(params)
leaves, treedef = jax.tree_util.tree_flatten(state)

serialized_state = serialize(state)
restored_state = deserialize(serialized_state)
# Serialization/deserialization unavoidably converts tuples to lists.
# Convert back to tuples to facilitate comparison.
restored_state = _lists_to_tuple(restored_state)
restored_leaves, restored_treedef = jax.tree_util.tree_flatten(restored_state)

self.assertEqual(treedef, restored_treedef)
# Serialization/deserialization currently converts tuples to lists. Compare
# tree structures, neglecting the difference between tuples and lists.
self.assertEqual(
tree_util.tree_structure(_lists_to_tuple(state)),
tree_util.tree_structure(_lists_to_tuple(restored_state)),
)

for a, b in zip(leaves, restored_leaves):
for a, b in zip(
tree_util.tree_leaves(state),
tree_util.tree_leaves(restored_state),
strict=True,
):
onp.testing.assert_array_equal(a, b)

@parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
def test_optimize(self, initial_params, opt):
def loss_fn(params):
leaves = jax.tree_util.tree_leaves(params)
leaves = tree_util.tree_leaves(params)
leaves_sum_squared = [jnp.sum(leaf**2) for leaf in leaves]
return jnp.sum(jnp.asarray(leaves_sum_squared))

@@ -197,16 +202,16 @@ def loss_fn(params):
value, grad = jax.value_and_grad(loss_fn)(params)
state = opt.update(grad=grad, value=value, params=params, state=state)

initial_treedef = jax.tree_util.tree_structure(initial_params)
treedef = jax.tree_util.tree_structure(params)
initial_treedef = tree_util.tree_structure(initial_params)
treedef = tree_util.tree_structure(params)
# Assert that the tree structure (i.e. including auxiliary quantities) is
# preserved by optimization.
self.assertEqual(treedef, initial_treedef)

@parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
def test_optimize_with_serialization(self, initial_params, opt):
def loss_fn(params):
leaves = jax.tree_util.tree_leaves(params)
leaves = tree_util.tree_leaves(params)
leaves_sum_squared = [jnp.sum(leaf**2) for leaf in leaves]
return jnp.sum(jnp.asarray(leaves_sum_squared))

@@ -248,16 +253,16 @@ def serdes(x):
# Serialization/deserialization unavoidably converts tuples to lists.
# Convert back to tuples to facilitate comparison.
p = _lists_to_tuple(p)
a_leaves, a_treedef = jax.tree_util.tree_flatten(p)
b_leaves, b_treedef = jax.tree_util.tree_flatten(ep)
a_leaves, a_treedef = tree_util.tree_flatten(p)
b_leaves, b_treedef = tree_util.tree_flatten(ep)
self.assertEqual(a_treedef, b_treedef)
for a, b in zip(a_leaves, b_leaves):
onp.testing.assert_array_equal(a, b)

for g, eg in zip(grad_list, expected_grad_list):
g = _lists_to_tuple(g)
a_leaves, a_treedef = jax.tree_util.tree_flatten(g)
b_leaves, b_treedef = jax.tree_util.tree_flatten(eg)
a_leaves, a_treedef = tree_util.tree_flatten(g)
b_leaves, b_treedef = tree_util.tree_flatten(eg)
self.assertEqual(a_treedef, b_treedef)
for a, b in zip(a_leaves, b_leaves):
onp.testing.assert_array_equal(a, b)
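
As the comments in the test above note, the serializer has no tuple type, so tuples in the optimizer state come back as lists, and the comparisons therefore normalize both sides with `_lists_to_tuple`. A tiny illustration of the same effect using a standard-library JSON round trip (the invrs_opt serializer is assumed to behave analogously in this respect):

import json

state = (0, {"x": (1, 2)})  # optimizer states are tuples before serialization
restored = json.loads(json.dumps(state))
print(restored)  # [0, {'x': [1, 2]}] -- every tuple has come back as a list
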
42 changes: 42 additions & 0 deletions tests/optimizers/test_lbfgsb.py
@@ -3,6 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

import dataclasses
import unittest

import jax
@@ -13,6 +14,7 @@
from parameterized import parameterized
from totypes import types

from invrs_opt.parameterization import filter_project
from invrs_opt.optimizers import lbfgsb

jax.config.update("jax_enable_x64", True)
@@ -682,3 +684,43 @@ def grad_fn(x):

# Compare the first few steps for the two schemes.
onp.testing.assert_allclose(scipy_values[:10], wrapper_values[:10])


class StepVariableParameterizationTest(unittest.TestCase):
def test_variable_parameterization(self):
# Create a custom parameterization whose update method increments `beta` by 1
# at each step.
p = filter_project.filter_project(beta=1)
p.update = lambda x, step: dataclasses.replace(x, beta=x.beta + 1)

opt = lbfgsb.parameterized_lbfgsb(
density_parameterization=p,
penalty=1.0,
)

target = jnp.asarray([[0, 1], [1, 0]], dtype=float)
target = jnp.kron(target, jnp.ones((10, 10)))

density = types.Density2DArray(
array=jnp.full(target.shape, 0.5),
lower_bound=0,
upper_bound=1,
minimum_width=4,
minimum_spacing=4,
)

state = opt.init(density)

def step_fn(state):
def loss_fn(density):
return jnp.sum((density.array - target) ** 2)

params = opt.params(state)
value, grad = jax.value_and_grad(loss_fn)(params)
return opt.update(grad=grad, value=value, params=params, state=state)

for _ in range(10):
state = step_fn(state)

# Check that beta has actually been incremented.
self.assertEqual(state[2].beta, 11)
43 changes: 42 additions & 1 deletion tests/optimizers/test_wrapped_optax.py
@@ -14,7 +14,7 @@
from parameterized import parameterized
from totypes import types

from invrs_opt.parameterization import transforms
from invrs_opt.parameterization import filter_project, transforms
from invrs_opt.optimizers import wrapped_optax


@@ -383,3 +383,44 @@ def step_fn(state):
onp.testing.assert_allclose(
batch_values, onp.transpose(no_batch_values, (1, 0)), atol=1e-4
)


class StepVariableParameterizationTest(unittest.TestCase):
def test_variable_parameterization(self):
# Create a custom parameterization whose update method increments `beta` by 1
# at each step.
p = filter_project.filter_project(beta=1)
p.update = lambda x, step: dataclasses.replace(x, beta=x.beta + 1)

opt = wrapped_optax.parameterized_wrapped_optax(
opt=optax.adam(0.01),
density_parameterization=p,
penalty=1.0,
)

target = jnp.asarray([[0, 1], [1, 0]], dtype=float)
target = jnp.kron(target, jnp.ones((10, 10)))

density = types.Density2DArray(
array=jnp.full(target.shape, 0.5),
lower_bound=0,
upper_bound=1,
minimum_width=4,
minimum_spacing=4,
)

state = opt.init(density)

def step_fn(state):
def loss_fn(density):
return jnp.sum((density.array - target) ** 2)

params = opt.params(state)
value, grad = jax.value_and_grad(loss_fn)(params)
return opt.update(grad=grad, value=value, params=params, state=state)

for _ in range(10):
state = step_fn(state)

# Check that beta has actually been incremented.
self.assertEqual(state[2].beta, 11)
