From 9aea137d3a003c0853a8cf296083dd9f2e3cd87d Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Wed, 25 Oct 2023 10:27:47 -0700 Subject: [PATCH] Add test and expand readme --- README.md | 20 ++++- src/invrs_opt/base.py | 36 +++++++- src/invrs_opt/lbfgsb/lbfgsb.py | 106 ++++++++++++++-------- src/invrs_opt/lbfgsb/transform.py | 64 ++++++------- tests/lbfgsb/test_lbfgsb.py | 54 ++++++++++- tests/lbfgsb/test_transform.py | 145 ++++++++++++++++++++++++++++-- 6 files changed, 337 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index e5b3960..3596e77 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,22 @@ -# invrs-opt - Algorithms for inverse design +# invrs-opt - Optimization algorithms + +## Overview + +The `invrs-opt` package defines an optimizer API and (currently) implements the L-BFGS-B optimization algorithm along with some variants. The API is intended to be general so that new algorithms can be accommodated, and is inspired by the functional optimizer approach used in jax. Example usage is as follows: + +```python +initial_params = ... + +optimizer = invrs_opt.lbfgsb() +state = optimizer.init() + +for _ in range(steps): + params = optimizer.params(state) + value, grad = jax.value_and_grad(loss_fn)(params) + state = optimizer.update(grad=grad, value=value, params=params, state=state) +``` + +Optimizers in `invrs-opt` are compatible with custom types defined in the [totypes](https://github.com/invrs-io/totypes) package. The basic `lbfgsb` optimizer enforces bounds for custom types, while the `density_lbfgsb` optimizer implements a filter-and-threshold operation for `DensityArray2D` types to ensure that solutions have the correct length scale. ## Install ``` diff --git a/src/invrs_opt/base.py b/src/invrs_opt/base.py index 4828c26..13b3dcb 100644 --- a/src/invrs_opt/base.py +++ b/src/invrs_opt/base.py @@ -4,20 +4,48 @@ """ import dataclasses -from typing import Any, Callable +from typing import Any, Protocol from totypes import json_utils PyTree = Any +class InitFn(Protocol): + """Callable which initializes an optimizer state.""" + + def __call__(self, params: PyTree) -> PyTree: + ... + + +class ParamsFn(Protocol): + """Callable which returns the parameters for an optimizer state.""" + + def __call__(self, state: PyTree) -> PyTree: + ... + + +class UpdateFn(Protocol): + """Callable which updates an optimizer state.""" + + def __call__( + self, + *, + grad: PyTree, + value: float, + params: PyTree, + state: PyTree, + ) -> PyTree: + ... + + @dataclasses.dataclass class Optimizer: """Stores the `(init, params, update)` function triple for an optimizer.""" - init: Callable[[PyTree], PyTree] - params: Callable[[PyTree], PyTree] - update: Callable[[PyTree, float, PyTree, PyTree], PyTree] + init: InitFn + params: ParamsFn + update: UpdateFn # Additional custom types and prefixes used for serializing optimizer state. diff --git a/src/invrs_opt/lbfgsb/lbfgsb.py b/src/invrs_opt/lbfgsb/lbfgsb.py index 4959a6e..8be71eb 100644 --- a/src/invrs_opt/lbfgsb/lbfgsb.py +++ b/src/invrs_opt/lbfgsb/lbfgsb.py @@ -36,6 +36,8 @@ # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme. MAXCOR_MAX_VALUE = 100 +MAXCOR_DEFAULT = 20 +LINE_SEARCH_MAX_STEPS_DEFAULT = 100 # Maps bound scenarios to integers. BOUNDS_MAP: Dict[Tuple[bool, bool], int] = { @@ -49,10 +51,47 @@ def lbfgsb( - maxcor: int, - line_search_max_steps: int, + maxcor: int = MAXCOR_DEFAULT, + line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT, ) -> base.Optimizer: - """Return an optimizer implementing the standard L-BFGS-B algorithm.""" + """Return an optimizer implementing the standard L-BFGS-B algorithm. + + This optimizer wraps scipy's implementation of the algorithm, and provides + a jax-style API to the scheme. The optimizer works with custom types such + as the `BoundedArray` to constrain the optimization variable. + + Example usage is as follows: + + def fn(x): + leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)] + return jnp.sum(jnp.asarray(leaves_sum_sq)) + + x0 = { + "a": jnp.ones((3,)), + "b": BoundedArray( + value=-jnp.ones((2, 5)), + lower_bound=-5, + upper_bound=5, + ), + } + opt = lbfgsb(maxcor=20, line_search_max_steps=100) + state = opt.init(x0) + for _ in range(10): + x = opt.params(state) + value, grad = jax.value_and_grad(fn)(x) + state = opt.update(grad, value, state) + + While the algorithm can work with pytrees of jax arrays, numpy arrays can + also be used. Thus, e.g. the optimizer can directly be used with autograd. + + Args: + maxcor: The maximum number of variable metric corrections used to define + the limited memory matrix, in the L-BFGS-B scheme. + line_search_max_steps: The maximum number of steps in the line search. + + Returns: + The `base.Optimizer`. + """ return transformed_lbfgsb( maxcor=maxcor, line_search_max_steps=line_search_max_steps, @@ -61,11 +100,30 @@ def lbfgsb( def density_lbfgsb( - maxcor: int, - line_search_max_steps: int, beta: float, + maxcor: int = MAXCOR_DEFAULT, + line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT, ) -> base.Optimizer: - """Return an L-BFGS-B optimizer with additional transforms for density arrays.""" + """Return an L-BFGS-B optimizer with additional transforms for density arrays. + + Parameters that are of type `DensityArray2D` are represented as latent parameters + that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by, + + transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta) + + where the kernel has a full-width at half-maximum determined by the minimum width + and spacing parameters of the `DensityArray2D`. Where the bounds differ, the + density is scaled before the transform is applied, and then unscaled afterwards. + + Args: + beta: Determines the steepness of the thresholding. + maxcor: The maximum number of variable metric corrections used to define + the limited memory matrix, in the L-BFGS-B scheme. + line_search_max_steps: The maximum number of steps in the line search. + + Returns: + The `base.Optimizer`. + """ def transform_fn(tree: PyTree) -> PyTree: return tree_util.tree_map( @@ -98,41 +156,13 @@ def transformed_lbfgsb( line_search_max_steps: int, transform_fn: Callable[[PyTree], PyTree], ) -> base.Optimizer: - """Construct an optimizer implementing the L-BFGS-B algorithm. + """Construct an latent parameter L-BFGS-B optimizer. The optimized parameters are termed latent parameters, from which the actual parameters returned by the optimizer are obtained using the `transform_fn`. In the simple case where this is just `lambda x: x` (i.e. the identity), this is equivalent to the standard L-BFGS-B algorithm. - This optimizer wraps scipy's implementation of the algorithm, and provides - a jax-style API to the scheme. The optimizer works with custom types such - as the `BoundedArray` to constrain the optimization variable. - - Example usage is as follows: - - def fn(x): - leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)] - return jnp.sum(jnp.asarray(leaves_sum_sq)) - - x0 = { - "a": jnp.ones((3,)), - "b": BoundedArray( - value=-jnp.ones((2, 5)), - lower_bound=-5, - upper_bound=5, - ), - } - opt = lbfgsb(maxcor=20, line_search_max_steps=100) - state = opt.init(x0) - for _ in range(10): - x = opt.params(state) - value, grad = jax.value_and_grad(fn)(x) - state = opt.update(grad, value, state) - - While the algorithm can work with pytrees of jax arrays, numpy arrays can - also be used. Thus, e.g. the optimizer can directly be used with autograd. - Args: maxcor: The maximum number of variable metric corrections used to define the limited memory matrix, in the L-BFGS-B scheme. @@ -204,7 +234,7 @@ def update_fn( return base.Optimizer( init=init_fn, params=params_fn, - update=update_fn, # type: ignore[arg-type] + update=update_fn, ) @@ -503,7 +533,7 @@ def _configure_bounds( lower_bound_array = [0.0 if x is None else x for x in lower_bound] upper_bound_array = [0.0 if x is None else x for x in upper_bound] return ( - onp.asarray(lower_bound_array), - onp.asarray(upper_bound_array), + onp.asarray(lower_bound_array, onp.float64), + onp.asarray(upper_bound_array, onp.float64), onp.asarray(bound_type), ) diff --git a/src/invrs_opt/lbfgsb/transform.py b/src/invrs_opt/lbfgsb/transform.py index 76bbd5f..07e22ac 100644 --- a/src/invrs_opt/lbfgsb/transform.py +++ b/src/invrs_opt/lbfgsb/transform.py @@ -16,38 +16,6 @@ GAUSSIAN_FWHM_SIZE_MULTIPLE: float = 3.0 -def normalized_array_from_density(density: types.Density2DArray) -> jnp.ndarray: - """Returns an array with values scaled to the range `(-1, 1)`.""" - value_mid = (density.upper_bound + density.lower_bound) / 2 - value_range = density.upper_bound - density.lower_bound - return jnp.asarray(2 * (density.array - value_mid) / value_range) - - -def rescale_array_for_density( - array: jnp.ndarray, - density: types.Density2DArray, -) -> jnp.ndarray: - """Rescales an array for the bounds defined by `density`.""" - value_mid = (density.upper_bound + density.lower_bound) / 2 - value_range = density.upper_bound - density.lower_bound - return array / 2 * value_range + value_mid - - -def apply_fixed_pixels(density: types.Density2DArray) -> types.Density2DArray: - """Set fixed pixels with their required values.""" - fixed_solid = density.fixed_solid - fixed_void = density.fixed_void - (array,), treedef = tree_util.tree_flatten(density) - if fixed_solid is not None: - array = jnp.where(fixed_solid, density.upper_bound, array) - if fixed_void is not None: - array = jnp.where(fixed_void, density.lower_bound, array) - transformed_density: types.Density2DArray = tree_util.tree_unflatten( - treedef, (array,) - ) - return transformed_density - - def density_gaussian_filter_and_tanh( density: types.Density2DArray, beta: float, @@ -95,6 +63,38 @@ def density_gaussian_filter_and_tanh( return transformed_density +def normalized_array_from_density(density: types.Density2DArray) -> jnp.ndarray: + """Returns an array with values scaled to the range `(-1, 1)`.""" + value_mid = (density.upper_bound + density.lower_bound) / 2 + value_range = density.upper_bound - density.lower_bound + return jnp.asarray(2 * (density.array - value_mid) / value_range) + + +def rescale_array_for_density( + array: jnp.ndarray, + density: types.Density2DArray, +) -> jnp.ndarray: + """Rescales an array for the bounds defined by `density`.""" + value_mid = (density.upper_bound + density.lower_bound) / 2 + value_range = density.upper_bound - density.lower_bound + return array / 2 * value_range + value_mid + + +def apply_fixed_pixels(density: types.Density2DArray) -> types.Density2DArray: + """Set fixed pixels with their required values.""" + fixed_solid = density.fixed_solid + fixed_void = density.fixed_void + (array,), treedef = tree_util.tree_flatten(density) + if fixed_solid is not None: + array = jnp.where(fixed_solid, density.upper_bound, array) + if fixed_void is not None: + array = jnp.where(fixed_void, density.lower_bound, array) + transformed_density: types.Density2DArray = tree_util.tree_unflatten( + treedef, (array,) + ) + return transformed_density + + def conv(x: jnp.ndarray, kernel: jnp.ndarray, padding: str) -> jnp.ndarray: """Convolves `x` with `kernel`.""" assert x.ndim == 4 diff --git a/tests/lbfgsb/test_lbfgsb.py b/tests/lbfgsb/test_lbfgsb.py index 3f8c2fc..598ff23 100644 --- a/tests/lbfgsb/test_lbfgsb.py +++ b/tests/lbfgsb/test_lbfgsb.py @@ -8,20 +8,66 @@ import jax import jax.numpy as jnp import numpy as onp -import parameterized +from parameterized import parameterized import scipy.optimize as spo from invrs_opt.lbfgsb import lbfgsb from totypes import types +class DensityLbfgsbBoundsTest(unittest.TestCase): + @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]]) + def test_respects_bounds(self, lower_bound, upper_bound, sign): + def loss_fn(density): + return sign * jnp.sum(density.array) + + params = types.Density2DArray( + array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + opt = lbfgsb.density_lbfgsb(beta=2) + state = opt.init(params) + for _ in range(10): + params = opt.params(state) + value, grad = jax.value_and_grad(loss_fn)(params) + state = opt.update(grad=grad, value=value, params=params, state=state) + + params = opt.params(state) + expected = upper_bound if sign < 0 else lower_bound + onp.testing.assert_allclose(params.array, expected) + + +class LbfgsbBoundsTest(unittest.TestCase): + @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]]) + def test_respects_bounds(self, lower_bound, upper_bound, sign): + def loss_fn(density): + return sign * jnp.sum(density.array) + + params = types.Density2DArray( + array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + opt = lbfgsb.lbfgsb() + state = opt.init(params) + for _ in range(10): + params = opt.params(state) + value, grad = jax.value_and_grad(loss_fn)(params) + state = opt.update(grad=grad, value=value, params=params, state=state) + + params = opt.params(state) + expected = upper_bound if sign < 0 else lower_bound + onp.testing.assert_allclose(params.array, expected) + + class LbfgsbInputValidationTest(unittest.TestCase): - @parameterized.parameterized.expand([[0], [-1], [500], ["abc"]]) + @parameterized.expand([[0], [-1], [500], ["abc"]]) def test_maxcor_validation(self, invalid_maxcor): with self.assertRaisesRegex(ValueError, "`maxcor` must be greater than 0"): lbfgsb.lbfgsb(maxcor=invalid_maxcor, line_search_max_steps=100) - @parameterized.parameterized.expand([[0], [-1], ["abc"]]) + @parameterized.expand([[0], [-1], ["abc"]]) def test_line_search_max_steps_validation(self, invalid_line_search_max_steps): with self.assertRaisesRegex(ValueError, "`line_search_max_steps` must be "): lbfgsb.lbfgsb( @@ -183,7 +229,7 @@ def loss_fn(x): # the other we are using float32. onp.testing.assert_allclose(scipy_values[:10], wrapper_values[:10], rtol=1e-6) - @parameterized.parameterized.expand( + @parameterized.expand( [ [2.0], [jnp.ones((3,))], diff --git a/tests/lbfgsb/test_transform.py b/tests/lbfgsb/test_transform.py index 4225837..299382c 100644 --- a/tests/lbfgsb/test_transform.py +++ b/tests/lbfgsb/test_transform.py @@ -3,6 +3,7 @@ Copyright (c) 2023 Martin F. Schubert """ +import dataclasses import unittest import jax @@ -13,15 +14,141 @@ from invrs_opt.lbfgsb import transform from totypes import types -TEST_KERNEL = onp.asarray( # Kernel is intentionally asymmetric. - [ - [0, 1, 1, 0, 0], - [1, 1, 1, 1, 1], - [0, 1, 1, 1, 1], - [0, 0, 1, 0, 0], - ], - dtype=bool, -) + +class GaussianFilterTest(unittest.TestCase): + @parameterized.expand([[1, 5], [3, 3], [5, 1]]) + def test_transformed_matches_expected(self, minimum_width, minimum_spacing): + array = onp.zeros((9, 9)) + array[4, 4] = 9 + density = types.Density2DArray( + array=array, + lower_bound=0, + upper_bound=1, + minimum_width=minimum_width, + minimum_spacing=minimum_spacing, + ) + beta = 1 + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + expected = onp.asarray( + [ + [0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12], + [0.12, 0.12, 0.13, 0.14, 0.14, 0.14, 0.13, 0.12, 0.12], + [0.12, 0.13, 0.15, 0.22, 0.27, 0.22, 0.15, 0.13, 0.12], + [0.12, 0.14, 0.22, 0.48, 0.64, 0.48, 0.22, 0.14, 0.12], + [0.12, 0.14, 0.27, 0.64, 0.82, 0.64, 0.27, 0.14, 0.12], + [0.12, 0.14, 0.22, 0.48, 0.64, 0.48, 0.22, 0.14, 0.12], + [0.12, 0.13, 0.15, 0.22, 0.27, 0.22, 0.15, 0.13, 0.12], + [0.12, 0.12, 0.13, 0.14, 0.14, 0.14, 0.13, 0.12, 0.12], + [0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12], + ] + ) + onp.testing.assert_allclose(transformed.array, expected, rtol=0.05) + + @parameterized.expand([[1, 1], [3, 1], [5, 1], [10, 1], [10, 0.5], [10, 2]]) + def test_ones_density_yields_tanh_beta(self, length_scale, upper_bound): + array = onp.ones((20, 20)) * upper_bound + density = types.Density2DArray( + array=array, + lower_bound=0, + upper_bound=upper_bound, + minimum_width=length_scale, + minimum_spacing=length_scale, + ) + beta = 1 + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + onp.testing.assert_allclose( + transformed.array, + (1 + onp.tanh(beta)) * 0.5 * upper_bound, + rtol=0.01, + ) + + def test_batch_matches_single(self): + beta = 4 + density = types.Density2DArray( + array=onp.arange(600).reshape((6, 10, 10)), + minimum_width=5, + minimum_spacing=5, + ) + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + for i in range(6): + transformed_single = transform.density_gaussian_filter_and_tanh( + density=dataclasses.replace( + density, + array=density.array[i, :, :], + ), + beta=beta, + ) + onp.testing.assert_allclose( + transformed.array[i, :, :], transformed_single.array + ) + + def test_periodic(self): + beta = 100 + array = onp.zeros((5, 5)) + array[0, 0] = 9 + + # No periodicity. + density = types.Density2DArray( + array, + minimum_spacing=3, + minimum_width=3, + periodic=(False, False), + lower_bound=0, + upper_bound=1, + ) + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + expected = onp.asarray( + [ + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + onp.testing.assert_allclose(transformed.array, expected, atol=0.01) + + # Periodic along the first axis. + density = dataclasses.replace(density, periodic=(True, False)) + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + expected = onp.asarray( + [ + [1, 1, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + ] + ) + onp.testing.assert_allclose(transformed.array, expected, atol=0.01) + + # Periodic along the second axis. + density = dataclasses.replace(density, periodic=(False, True)) + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + expected = onp.asarray( + [ + [1, 1, 1, 1, 1], + [1, 1, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + onp.testing.assert_allclose(transformed.array, expected, atol=0.01) + + # Periodic along both axes. + density = dataclasses.replace(density, periodic=(True, True)) + transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta) + expected = onp.asarray( + [ + [1, 1, 0, 0, 1], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ) + onp.testing.assert_allclose(transformed.array, expected, atol=0.01) class RescaleTest(unittest.TestCase):