move update functionality to parameterization
Martin Schubert authored and committed Oct 20, 2024
1 parent 2f688b0 commit 88a28e5
Showing 10 changed files with 124 additions and 45 deletions.
6 changes: 5 additions & 1 deletion src/invrs_opt/experimental/client.py
```diff
@@ -132,7 +132,11 @@ def params_fn(
         response = json.loads(get_response.text)
         return json_utils.pytree_from_json(response[labels.PARAMS])
 
-    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
+    return base.Optimizer(
+        init=init_fn,
+        update=update_fn,  # type: ignore[arg-type]
+        params=params_fn,
+    )
 
 
 # -----------------------------------------------------------------------------
```
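The `base.Optimizer` triple constructed here is meant to be driven in a loop: `init` builds state, `params` extracts the current parameters, and `update` advances the state. A minimal self-contained sketch of that loop, using a toy gradient-descent optimizer rather than anything from this repository:

```python
import jax
import jax.numpy as jnp


def toy_optimizer(learning_rate=0.1):
    """Illustrative init/params/update triple; not part of this commit."""

    def init_fn(params):
        return params  # state is just the params themselves

    def params_fn(state):
        return state

    def update_fn(*, grad, value, params, state):
        del value, params  # a plain gradient step needs only grad and state
        return jax.tree_util.tree_map(lambda s, g: s - learning_rate * g, state, grad)

    return init_fn, params_fn, update_fn


def loss_fn(params):
    return jnp.sum(params["x"] ** 2)


init, get_params, update = toy_optimizer()
state = init({"x": jnp.asarray([1.0, -2.0])})
for _ in range(10):
    params = get_params(state)
    value, grad = jax.value_and_grad(loss_fn)(params)
    state = update(grad=grad, value=value, params=params, state=state)
```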
3 changes: 2 additions & 1 deletion src/invrs_opt/optimizers/base.py
```diff
@@ -7,6 +7,7 @@
 import inspect
 from typing import Any, Protocol
 
+import jax.numpy as jnp
 import optax  # type: ignore[import-untyped]
 from totypes import json_utils
 
@@ -34,7 +35,7 @@ def __call__(
         self,
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: PyTree,
     ) -> PyTree:
```
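The `value` annotation changes from `float` to `jnp.ndarray` because in a JAX workflow the loss reaching `update` is a zero-dimensional array (possibly a tracer), not a Python float. A quick standalone illustration:

```python
import jax
import jax.numpy as jnp

value, grad = jax.value_and_grad(lambda x: jnp.sum(x**2))(jnp.asarray([1.0, 2.0]))
# `value` is a 0-d jax array, not a Python float:
print(value.shape, value.dtype)  # () float32
```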
44 changes: 30 additions & 14 deletions src/invrs_opt/optimizers/lbfgsb.py
```diff
@@ -335,7 +335,7 @@ def params_fn(state: LbfgsbState) -> PyTree:
     def update_fn(
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: LbfgsbState,
     ) -> LbfgsbState:
@@ -349,12 +349,14 @@ def _update_pure(
         ) -> Tuple[PyTree, NumpyLbfgsbDict]:
             assert onp.size(value) == 1
             scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
+            flat_latent_params = scipy_lbfgsb_state.x.copy()
             scipy_lbfgsb_state.update(
                 grad=onp.array(flat_latent_grad, dtype=onp.float64),
                 value=onp.array(value, dtype=onp.float64),
             )
-            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
-            return flat_latent_params, scipy_lbfgsb_state.to_dict()
+            updated_flat_latent_params = scipy_lbfgsb_state.x
+            flat_latent_updates = updated_flat_latent_params - flat_latent_params
+            return flat_latent_updates, scipy_lbfgsb_state.to_dict()
 
         step, _, latent_params, jax_lbfgsb_state = state
         metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -395,16 +397,21 @@ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
             latents_grad
         )  # type: ignore[no-untyped-call]
 
-        flat_latents, jax_lbfgsb_state = jax.pure_callback(
+        flat_latent_updates, jax_lbfgsb_state = jax.pure_callback(
             _update_pure,
             (flat_latents_grad, jax_lbfgsb_state),
             flat_latents_grad,
+            value,
             jax_lbfgsb_state,
         )
-        latents = unflatten_fn(flat_latents)
-        latent_params = param_base.combine_density_metadata(metadata, latents)
-        latent_params = _update_parameterized_densities(latent_params, step)
+        latent_updates = unflatten_fn(flat_latent_updates)
+        latent_params = _apply_updates(
+            params=latent_params,
+            updates=param_base.combine_density_metadata(metadata, latent_updates),
+            value=value,
+            step=step,
+        )
         latent_params = _clip(latent_params)
         params = _params_from_latent_params(latent_params)
         return step + 1, params, latent_params, jax_lbfgsb_state
 
@@ -433,15 +440,24 @@ def _leaf_params_from_latents(leaf: Any) -> Any:
             is_leaf=_is_parameterized_density,
         )
 
-    def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
-        def _update_leaf(leaf: Any) -> Any:
-            if not _is_parameterized_density(leaf):
-                return leaf
-            return density_parameterization.update(leaf, step)
+    def _apply_updates(
+        params: PyTree,
+        updates: PyTree,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PyTree:
+        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+            if _is_parameterized_density(leaf):
+                return density_parameterization.update(
+                    params=leaf, updates=update, value=value, step=step
+                )
+            else:
+                return optax.apply_updates(params=leaf, updates=update)
 
         return tree_util.tree_map(
-            _update_leaf,
-            latent_params,
+            _leaf_apply_updates,
+            updates,
+            params,
             is_leaf=_is_parameterized_density,
         )
 
```
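The pivotal change in `_update_pure` is that it snapshots the L-BFGS-B iterate before the step and returns the difference, so the caller applies a delta instead of overwriting the latents. For ordinary leaves this is equivalent to the old behavior, as a small sketch with made-up numbers shows:

```python
import numpy as onp

x_before = onp.asarray([0.20, 0.80])  # latents before the L-BFGS-B step
x_after = onp.asarray([0.15, 0.85])   # hypothetical iterate after the step

flat_latent_updates = x_after - x_before  # what _update_pure now returns
assert onp.allclose(x_before + flat_latent_updates, x_after)
```

The payoff is in `_apply_updates` above: parameterized-density leaves receive the delta through `density_parameterization.update`, while all other leaves fall back to `optax.apply_updates`.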
39 changes: 25 additions & 14 deletions src/invrs_opt/optimizers/wrapped_optax.py
```diff
@@ -197,12 +197,12 @@ def params_fn(state: WrappedOptaxState) -> PyTree:
     def update_fn(
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: WrappedOptaxState,
     ) -> WrappedOptaxState:
         """Updates the state."""
-        del value, params
+        del params
 
         step, params, latent_params, opt_state = state
         metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -233,12 +233,14 @@ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
             lambda a, b: a + b, latents_grad, constraint_loss_grad
         )
 
-        updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
-        latents = optax.apply_updates(params=latents, updates=updates)
-
-        latent_params = param_base.combine_density_metadata(metadata, latents)
+        latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
+        latent_params = _apply_updates(
+            params=latent_params,
+            updates=param_base.combine_density_metadata(metadata, latent_updates),
+            value=value,
+            step=step,
+        )
         latent_params = _clip(latent_params)
-        latent_params = _update_parameterized_densities(latent_params, step + 1)
         params = _params_from_latent_params(latent_params)
         return (step + 1, params, latent_params, opt_state)
 
@@ -267,15 +269,24 @@ def _leaf_params_from_latents(leaf: Any) -> Any:
             is_leaf=_is_parameterized_density,
         )
 
-    def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
-        def _update_leaf(leaf: Any) -> Any:
-            if not _is_parameterized_density(leaf):
-                return leaf
-            return density_parameterization.update(leaf, step)
+    def _apply_updates(
+        params: PyTree,
+        updates: PyTree,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PyTree:
+        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+            if _is_parameterized_density(leaf):
+                return density_parameterization.update(
+                    params=leaf, updates=update, value=value, step=step
+                )
+            else:
+                return optax.apply_updates(params=leaf, updates=update)
 
         return tree_util.tree_map(
-            _update_leaf,
-            latent_params,
+            _leaf_apply_updates,
+            updates,
+            params,
             is_leaf=_is_parameterized_density,
         )
 
```
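As in the L-BFGS-B case, computing updates is now decoupled from applying them: `opt.update` produces raw optax updates, and `_apply_updates` decides per leaf how they land. The compute/apply split with a plain optax optimizer, in isolation (toy params, not this repository's state layout):

```python
import jax.numpy as jnp
import optax

opt = optax.adam(learning_rate=1e-2)
params = {"w": jnp.zeros(3)}
opt_state = opt.init(params)

grads = {"w": jnp.ones(3)}
updates, opt_state = opt.update(grads, opt_state, params)  # compute updates
params = optax.apply_updates(params, updates)              # apply them separately
```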
8 changes: 7 additions & 1 deletion src/invrs_opt/parameterization/base.py
```diff
@@ -97,7 +97,13 @@ def __call__(self, params: PyTree) -> jnp.ndarray:
 class UpdateFn(Protocol):
     """Performs the required update of a parameterized density for the given step."""
 
-    def __call__(self, params: PyTree, step: int) -> PyTree:
+    def __call__(
+        self,
+        params: PyTree,
+        updates: PyTree,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PyTree:
         ...
 
 
```
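Any function accepting `params`, `updates`, `value`, and `step` and returning the updated pytree satisfies this protocol. A minimal conforming sketch, with hypothetical `Metadata`/`MyParams` containers standing in for the real parameterization types:

```python
import dataclasses
from typing import Any

import jax.numpy as jnp
from jax import tree_util


@dataclasses.dataclass
class Metadata:
    beta: float


@dataclasses.dataclass
class MyParams:
    latents: Any
    metadata: Metadata


def update_fn(params: MyParams, updates: MyParams, value: jnp.ndarray, step: int) -> MyParams:
    del value  # unused here, but available for value-dependent schedules
    # Apply additive updates to the latents and anneal `beta` with the step.
    latents = tree_util.tree_map(lambda a, b: a + b, params.latents, updates.latents)
    return MyParams(latents=latents, metadata=Metadata(beta=1.0 + step // 10))
```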
16 changes: 13 additions & 3 deletions src/invrs_opt/parameterization/filter_project.py
```diff
@@ -115,10 +115,20 @@ def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
         del params
         return jnp.asarray(0.0)
 
-    def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
+    def update_fn(
+        params: FilterProjectParams,
+        updates: FilterProjectParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> FilterProjectParams:
         """Perform updates to `params` required for the given `step`."""
-        del step
-        return params
+        del step, value
+        return FilterProjectParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         to_density=to_density_fn,
```
16 changes: 13 additions & 3 deletions src/invrs_opt/parameterization/gaussian_levelset.py
```diff
@@ -229,10 +229,20 @@ def constraints_fn(
             pad_pixels=pad_pixels,
         )
 
-    def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
+    def update_fn(
+        params: GaussianLevelsetParams,
+        updates: GaussianLevelsetParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> GaussianLevelsetParams:
         """Perform updates to `params` required for the given `step`."""
-        del step
-        return params
+        del step, value
+        return GaussianLevelsetParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         to_density=to_density_fn,
```
17 changes: 14 additions & 3 deletions src/invrs_opt/parameterization/pixel.py
```diff
@@ -52,9 +52,20 @@ def constraints_fn(params: PixelParams) -> jnp.ndarray:
         del params
         return jnp.asarray(0.0)
 
-    def update_fn(params: PixelParams, step: int) -> PixelParams:
-        del step
-        return params
+    def update_fn(
+        params: PixelParams,
+        updates: PixelParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PixelParams:
+        """Perform updates to `params` required for the given `step`."""
+        del step, value
+        return PixelParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         from_density=from_density_fn,
```
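All three built-in parameterizations now share the same additive rule for their latents, with metadata passed through untouched. For such leaves the rule coincides exactly with the optax fallback in `_apply_updates`, which a tiny check confirms (toy latents, illustrative only):

```python
import jax.numpy as jnp
import optax
from jax import tree_util

latents = {"density": jnp.asarray([0.40, 0.60])}
updates = {"density": jnp.asarray([0.05, -0.05])}

added = tree_util.tree_map(lambda a, b: a + b, latents, updates)
via_optax = optax.apply_updates(params=latents, updates=updates)
assert tree_util.tree_all(tree_util.tree_map(jnp.allclose, added, via_optax))
```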
9 changes: 7 additions & 2 deletions tests/optimizers/test_lbfgsb.py
```diff
@@ -3,6 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
+import copy
 import unittest
 
 import jax
@@ -691,8 +692,12 @@ def test_variable_parameterization(self):
         # at each step.
         p = filter_project.filter_project(beta=1)
 
-        def update_fn(params, step):
-            del step
+        _original_update_fn = copy.deepcopy(p.update)
+
+        def update_fn(step, params, value, updates):
+            params = _original_update_fn(
+                step=step, params=params, value=value, updates=updates
+            )
             params.metadata.beta += 1
             return params
 
```
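The updated test no longer replaces the parameterization's update outright; it wraps the original and layers a beta increment on top. The wrapping pattern in isolation (`parameterization` and its mutable `metadata.beta` are assumptions mirroring the test):

```python
import copy


def with_beta_schedule(parameterization):
    # Capture the original callable before it is swapped out, so the
    # wrapper can delegate to it without recursing into itself.
    original_update = copy.deepcopy(parameterization.update)

    def update_fn(step, params, value, updates):
        params = original_update(step=step, params=params, value=value, updates=updates)
        params.metadata.beta += 1  # sharpen the projection as optimization proceeds
        return params

    return update_fn
```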
11 changes: 8 additions & 3 deletions tests/optimizers/test_wrapped_optax.py
```diff
@@ -3,6 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
+import copy
 import dataclasses
 import unittest
 
@@ -148,7 +149,7 @@ def loss_fn(params):
 
         onp.testing.assert_array_equal(values, expected_values)
 
-    def test_trajectory_matches_scipy_density_2d(self):
+    def test_trajectory_matches_optax_density_2d(self):
         initial_params = {
             "a": jnp.asarray([1.0, 2.0]),
             "b": types.BoundedArray(
@@ -391,8 +392,12 @@ def test_variable_parameterization(self):
         # at each step.
         p = filter_project.filter_project(beta=1)
 
-        def update_fn(params, step):
-            del step
+        _original_update_fn = copy.deepcopy(p.update)
+
+        def update_fn(step, params, value, updates):
+            params = _original_update_fn(
+                step=step, params=params, value=value, updates=updates
+            )
             params.metadata.beta += 1
             return params
 
```
