diff --git a/src/invrs_opt/experimental/client.py b/src/invrs_opt/experimental/client.py
index 0c3fbdc..1fd67d0 100644
--- a/src/invrs_opt/experimental/client.py
+++ b/src/invrs_opt/experimental/client.py
@@ -132,7 +132,11 @@ def params_fn(
         response = json.loads(get_response.text)
         return json_utils.pytree_from_json(response[labels.PARAMS])
 
-    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
+    return base.Optimizer(
+        init=init_fn,
+        update=update_fn,  # type: ignore[arg-type]
+        params=params_fn,
+    )
 
 
 # -----------------------------------------------------------------------------
diff --git a/src/invrs_opt/optimizers/base.py b/src/invrs_opt/optimizers/base.py
index 08a8809..7ec66a0 100644
--- a/src/invrs_opt/optimizers/base.py
+++ b/src/invrs_opt/optimizers/base.py
@@ -7,6 +7,7 @@ import inspect
 from typing import Any, Protocol
 
+import jax.numpy as jnp
 import optax  # type: ignore[import-untyped]
 from totypes import json_utils
 
@@ -34,7 +35,7 @@ def __call__(
         self,
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: PyTree,
     ) -> PyTree:
diff --git a/src/invrs_opt/optimizers/lbfgsb.py b/src/invrs_opt/optimizers/lbfgsb.py
index c3251a8..dd606f3 100644
--- a/src/invrs_opt/optimizers/lbfgsb.py
+++ b/src/invrs_opt/optimizers/lbfgsb.py
@@ -335,7 +335,7 @@ def params_fn(state: LbfgsbState) -> PyTree:
     def update_fn(
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: LbfgsbState,
     ) -> LbfgsbState:
@@ -349,12 +349,14 @@ def _update_pure(
        ) -> Tuple[PyTree, NumpyLbfgsbDict]:
            assert onp.size(value) == 1
            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
+           flat_latent_params = scipy_lbfgsb_state.x.copy()
            scipy_lbfgsb_state.update(
                grad=onp.array(flat_latent_grad, dtype=onp.float64),
                value=onp.array(value, dtype=onp.float64),
            )
-           flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
-           return flat_latent_params, scipy_lbfgsb_state.to_dict()
+           updated_flat_latent_params = scipy_lbfgsb_state.x
+           flat_latent_updates = updated_flat_latent_params - flat_latent_params
+           return flat_latent_updates, scipy_lbfgsb_state.to_dict()
 
        step, _, latent_params, jax_lbfgsb_state = state
        metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -395,16 +397,21 @@ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
            latents_grad
        )  # type: ignore[no-untyped-call]
 
-       flat_latents, jax_lbfgsb_state = jax.pure_callback(
+       flat_latent_updates, jax_lbfgsb_state = jax.pure_callback(
            _update_pure,
            (flat_latents_grad, jax_lbfgsb_state),
            flat_latents_grad,
            value,
            jax_lbfgsb_state,
        )
-       latents = unflatten_fn(flat_latents)
-       latent_params = param_base.combine_density_metadata(metadata, latents)
-       latent_params = _update_parameterized_densities(latent_params, step)
+       latent_updates = unflatten_fn(flat_latent_updates)
+       latent_params = _apply_updates(
+           params=latent_params,
+           updates=param_base.combine_density_metadata(metadata, latent_updates),
+           value=value,
+           step=step,
+       )
+       latent_params = _clip(latent_params)
        params = _params_from_latent_params(latent_params)
        return step + 1, params, latent_params, jax_lbfgsb_state
 
@@ -433,15 +440,24 @@ def _leaf_params_from_latents(leaf: Any) -> Any:
            is_leaf=_is_parameterized_density,
        )
 
-   def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
-       def _update_leaf(leaf: Any) -> Any:
-           if not _is_parameterized_density(leaf):
-               return leaf
-           return density_parameterization.update(leaf, step)
+   def _apply_updates(
+       params: PyTree,
+       updates: PyTree,
+       value: jnp.ndarray,
+       step: int,
+   ) -> PyTree:
+       def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+           if _is_parameterized_density(leaf):
+               return density_parameterization.update(
+                   params=leaf, updates=update, value=value, step=step
+               )
+           else:
+               return optax.apply_updates(params=leaf, updates=update)
 
        return tree_util.tree_map(
-           _update_leaf,
-           latent_params,
+           _leaf_apply_updates,
+           updates,
+           params,
            is_leaf=_is_parameterized_density,
        )
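The L-BFGS-B change above reverses the direction of data flow out of the callback: `_update_pure` now returns the step *delta* (updated minus previous latents) rather than the updated latents themselves, so the traced caller can route the delta through `_apply_updates` and give parameterized densities their own update path. Below is a minimal sketch of the delta-returning `jax.pure_callback` pattern; `_host_step` is a hypothetical stand-in for the `ScipyLbfgsbState.update` call, not the library code:

    import jax
    import jax.numpy as jnp
    import numpy as onp

    def _host_step(x, grad):
        # Hypothetical host-side optimizer step standing in for scipy L-BFGS-B.
        x_old = onp.asarray(x)
        x_new = x_old - 0.1 * onp.asarray(grad)
        return x_new - x_old  # return the delta, not the new parameters

    @jax.jit
    def step(x, grad):
        # The callback result matches x in shape and dtype.
        updates = jax.pure_callback(
            _host_step, jax.ShapeDtypeStruct(x.shape, x.dtype), x, grad
        )
        return x + updates  # apply the delta on the JAX side

    print(step(jnp.ones(3), jnp.ones(3)))  # [0.9 0.9 0.9]
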
diff --git a/src/invrs_opt/optimizers/wrapped_optax.py b/src/invrs_opt/optimizers/wrapped_optax.py
index 48606e4..ab08b5e 100644
--- a/src/invrs_opt/optimizers/wrapped_optax.py
+++ b/src/invrs_opt/optimizers/wrapped_optax.py
@@ -197,12 +197,12 @@ def params_fn(state: WrappedOptaxState) -> PyTree:
     def update_fn(
         *,
         grad: PyTree,
-        value: float,
+        value: jnp.ndarray,
         params: PyTree,
         state: WrappedOptaxState,
     ) -> WrappedOptaxState:
         """Updates the state."""
-        del value, params
+        del params
 
         step, params, latent_params, opt_state = state
         metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -233,12 +233,14 @@ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
            lambda a, b: a + b, latents_grad, constraint_loss_grad
        )
 
-       updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
-       latents = optax.apply_updates(params=latents, updates=updates)
-
-       latent_params = param_base.combine_density_metadata(metadata, latents)
+       latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
+       latent_params = _apply_updates(
+           params=latent_params,
+           updates=param_base.combine_density_metadata(metadata, latent_updates),
+           value=value,
+           step=step,
+       )
        latent_params = _clip(latent_params)
-       latent_params = _update_parameterized_densities(latent_params, step + 1)
        params = _params_from_latent_params(latent_params)
 
        return (step + 1, params, latent_params, opt_state)
@@ -267,15 +269,24 @@ def _leaf_params_from_latents(leaf: Any) -> Any:
            is_leaf=_is_parameterized_density,
        )
 
-   def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
-       def _update_leaf(leaf: Any) -> Any:
-           if not _is_parameterized_density(leaf):
-               return leaf
-           return density_parameterization.update(leaf, step)
+   def _apply_updates(
+       params: PyTree,
+       updates: PyTree,
+       value: jnp.ndarray,
+       step: int,
+   ) -> PyTree:
+       def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+           if _is_parameterized_density(leaf):
+               return density_parameterization.update(
+                   params=leaf, updates=update, value=value, step=step
+               )
+           else:
+               return optax.apply_updates(params=leaf, updates=update)
 
        return tree_util.tree_map(
-           _update_leaf,
-           latent_params,
+           _leaf_apply_updates,
+           updates,
+           params,
            is_leaf=_is_parameterized_density,
        )
diff --git a/src/invrs_opt/parameterization/base.py b/src/invrs_opt/parameterization/base.py
index 28c2993..e436702 100644
--- a/src/invrs_opt/parameterization/base.py
+++ b/src/invrs_opt/parameterization/base.py
@@ -97,7 +97,13 @@ def __call__(self, params: PyTree) -> jnp.ndarray:
 
 
 class UpdateFn(Protocol):
     """Performs the required update of a parameterized density for the given step."""
 
-    def __call__(self, params: PyTree, step: int) -> PyTree:
+    def __call__(
+        self,
+        params: PyTree,
+        updates: PyTree,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PyTree:
         ...
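With the widened `UpdateFn` protocol, a parameterization now receives the optimizer-computed `updates` along with the scalar loss `value` and the `step`, instead of only being ticked forward by step count. A minimal conforming implementation might look like the following sketch; `MyParams` and its `latents`/`metadata` fields are illustrative names that mirror the concrete parameterizations changed below:

    import dataclasses
    from typing import Any

    import jax.numpy as jnp
    from jax import tree_util

    @dataclasses.dataclass
    class MyParams:  # illustrative stand-in for e.g. FilterProjectParams
        latents: Any
        metadata: Any

    def update_fn(
        params: MyParams,
        updates: MyParams,
        value: jnp.ndarray,
        step: int,
    ) -> MyParams:
        # A purely gradient-based parameterization just adds the updates to
        # its latents; `value` and `step` remain available for schedules
        # (e.g. stepping a projection strength) but are unused here.
        del value, step
        return MyParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )
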
diff --git a/src/invrs_opt/parameterization/filter_project.py b/src/invrs_opt/parameterization/filter_project.py
index 58d7a9d..6c692a6 100644
--- a/src/invrs_opt/parameterization/filter_project.py
+++ b/src/invrs_opt/parameterization/filter_project.py
@@ -115,10 +115,20 @@ def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
         del params
         return jnp.asarray(0.0)
 
-    def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
+    def update_fn(
+        params: FilterProjectParams,
+        updates: FilterProjectParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> FilterProjectParams:
         """Perform updates to `params` required for the given `step`."""
-        del step
-        return params
+        del step, value
+        return FilterProjectParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         to_density=to_density_fn,
diff --git a/src/invrs_opt/parameterization/gaussian_levelset.py b/src/invrs_opt/parameterization/gaussian_levelset.py
index d9d95b1..2362dfb 100644
--- a/src/invrs_opt/parameterization/gaussian_levelset.py
+++ b/src/invrs_opt/parameterization/gaussian_levelset.py
@@ -229,10 +229,20 @@ def constraints_fn(
             pad_pixels=pad_pixels,
         )
 
-    def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
+    def update_fn(
+        params: GaussianLevelsetParams,
+        updates: GaussianLevelsetParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> GaussianLevelsetParams:
         """Perform updates to `params` required for the given `step`."""
-        del step
-        return params
+        del step, value
+        return GaussianLevelsetParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         to_density=to_density_fn,
diff --git a/src/invrs_opt/parameterization/pixel.py b/src/invrs_opt/parameterization/pixel.py
index 923d08b..676a50d 100644
--- a/src/invrs_opt/parameterization/pixel.py
+++ b/src/invrs_opt/parameterization/pixel.py
@@ -52,9 +52,20 @@ def constraints_fn(params: PixelParams) -> jnp.ndarray:
         del params
         return jnp.asarray(0.0)
 
-    def update_fn(params: PixelParams, step: int) -> PixelParams:
-        del step
-        return params
+    def update_fn(
+        params: PixelParams,
+        updates: PixelParams,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PixelParams:
+        """Perform updates to `params` required for the given `step`."""
+        del step, value
+        return PixelParams(
+            latents=tree_util.tree_map(
+                lambda a, b: a + b, params.latents, updates.latents
+            ),
+            metadata=params.metadata,
+        )
 
     return base.Density2DParameterization(
         from_density=from_density_fn,
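All three built-in parameterizations now apply the same additive update to their latents; what varies is only which leaves of the optimizer's update tree take the custom path. The two-tree `tree_map` in `_apply_updates` is what performs that routing. Below is a self-contained sketch of the dispatch, with a `NamedTuple` named `Special` standing in for a parameterized density (a pytree node, which is why `is_leaf` is needed); the names are illustrative, not library code:

    from typing import NamedTuple

    import jax.numpy as jnp
    from jax import tree_util

    class Special(NamedTuple):  # stand-in for a parameterized density
        v: jnp.ndarray

    def _is_special(x) -> bool:
        return isinstance(x, Special)

    def _apply(update, leaf):
        if _is_special(leaf):
            # Custom path, standing in for density_parameterization.update.
            return Special(v=leaf.v + update.v)
        return leaf + update  # plain additive, optax-style path

    params = {"a": jnp.asarray(1.0), "d": Special(v=jnp.asarray(2.0))}
    updates = {"a": jnp.asarray(0.5), "d": Special(v=jnp.asarray(0.5))}

    # As in _apply_updates, `updates` is the first tree, and is_leaf stops
    # recursion at Special nodes so that _apply sees them whole.
    out = tree_util.tree_map(_apply, updates, params, is_leaf=_is_special)
    # out["a"] == 1.5; out["d"].v == 2.5
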
diff --git a/tests/optimizers/test_lbfgsb.py b/tests/optimizers/test_lbfgsb.py
index 614cf12..a114909 100644
--- a/tests/optimizers/test_lbfgsb.py
+++ b/tests/optimizers/test_lbfgsb.py
@@ -3,6 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
+import copy
 import unittest
 
 import jax
@@ -691,8 +692,12 @@ def test_variable_parameterization(self):
         # at each step.
         p = filter_project.filter_project(beta=1)
 
-        def update_fn(params, step):
-            del step
+        _original_update_fn = copy.deepcopy(p.update)
+
+        def update_fn(step, params, value, updates):
+            params = _original_update_fn(
+                step=step, params=params, value=value, updates=updates
+            )
             params.metadata.beta += 1
             return params
diff --git a/tests/optimizers/test_wrapped_optax.py b/tests/optimizers/test_wrapped_optax.py
index 8dcc75c..fa285cb 100644
--- a/tests/optimizers/test_wrapped_optax.py
+++ b/tests/optimizers/test_wrapped_optax.py
@@ -3,6 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
+import copy
 import dataclasses
 import unittest
@@ -148,7 +149,7 @@ def loss_fn(params):
 
         onp.testing.assert_array_equal(values, expected_values)
 
-    def test_trajectory_matches_scipy_density_2d(self):
+    def test_trajectory_matches_optax_density_2d(self):
         initial_params = {
             "a": jnp.asarray([1.0, 2.0]),
             "b": types.BoundedArray(
@@ -391,8 +392,12 @@ def test_variable_parameterization(self):
         # at each step.
         p = filter_project.filter_project(beta=1)
 
-        def update_fn(params, step):
-            del step
+        _original_update_fn = copy.deepcopy(p.update)
+
+        def update_fn(step, params, value, updates):
+            params = _original_update_fn(
+                step=step, params=params, value=value, updates=updates
+            )
             params.metadata.beta += 1
             return params
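For context, none of this changes the driver loop a user writes; it only changes what flows through it. A hedged end-to-end sketch, assuming the package's top-level `lbfgsb` entry point and an illustrative quadratic loss; `value` reaches `update` as a scalar `jnp.ndarray` (a tracer under `jit`), which is what the corrected `value: jnp.ndarray` annotations above reflect:

    import jax
    import jax.numpy as jnp

    import invrs_opt

    def loss_fn(params):
        return jnp.sum(params**2)

    opt = invrs_opt.lbfgsb()
    params = jnp.asarray([1.0, -1.0])
    state = opt.init(params)

    for _ in range(5):
        params = opt.params(state)
        # `value` is a scalar jnp.ndarray, not a Python float.
        value, grad = jax.value_and_grad(loss_fn)(params)
        state = opt.update(grad=grad, value=value, params=params, state=state)
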