Add test and expand readme
Martin Schubert authored and Martin Schubert committed Oct 25, 2023
1 parent 1070a84 commit 9aea137
Showing 6 changed files with 337 additions and 88 deletions.
20 changes: 19 additions & 1 deletion README.md
@@ -1,4 +1,22 @@
-# invrs-opt - Algorithms for inverse design
+# invrs-opt - Optimization algorithms

## Overview

The `invrs-opt` package defines an optimizer API and (currently) implements the L-BFGS-B optimization algorithm along with some variants. The API is intended to be general so that new algorithms can be accommodated, and is inspired by the functional optimizer approach used in jax. Example usage is as follows:

```python
initial_params = ...

optimizer = invrs_opt.lbfgsb()
state = optimizer.init(initial_params)

for _ in range(steps):
params = optimizer.params(state)
value, grad = jax.value_and_grad(loss_fn)(params)
state = optimizer.update(grad=grad, value=value, params=params, state=state)
```

Optimizers in `invrs-opt` are compatible with custom types defined in the [totypes](https://github.com/invrs-io/totypes) package. The basic `lbfgsb` optimizer enforces bounds for custom types, while the `density_lbfgsb` optimizer implements a filter-and-threshold operation for `Density2DArray` types to ensure that solutions have the correct length scale.
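A minimal sketch of `density_lbfgsb` usage, assuming `invrs_opt` exposes it at the top level and using the `Density2DArray` constructor arguments seen in this repository's tests:

```python
import jax
import jax.numpy as jnp
import invrs_opt
from totypes import types

# Hypothetical density design variable with bounds (0, 1).
params = types.Density2DArray(
    array=jnp.full((5, 5), 0.5),
    lower_bound=0,
    upper_bound=1,
)

optimizer = invrs_opt.density_lbfgsb(beta=2)
state = optimizer.init(params)
for _ in range(10):
    params = optimizer.params(state)
    value, grad = jax.value_and_grad(lambda d: jnp.sum(d.array))(params)
    state = optimizer.update(grad=grad, value=value, params=params, state=state)
```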

## Install
```
36 changes: 32 additions & 4 deletions src/invrs_opt/base.py
@@ -4,20 +4,48 @@
"""

import dataclasses
-from typing import Any, Callable
+from typing import Any, Protocol

from totypes import json_utils

PyTree = Any


class InitFn(Protocol):
    """Callable which initializes an optimizer state."""

    def __call__(self, params: PyTree) -> PyTree:
        ...


class ParamsFn(Protocol):
    """Callable which returns the parameters for an optimizer state."""

    def __call__(self, state: PyTree) -> PyTree:
        ...


class UpdateFn(Protocol):
    """Callable which updates an optimizer state."""

    def __call__(
        self,
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: PyTree,
    ) -> PyTree:
        ...


@dataclasses.dataclass
class Optimizer:
    """Stores the `(init, params, update)` function triple for an optimizer."""

-    init: Callable[[PyTree], PyTree]
-    params: Callable[[PyTree], PyTree]
-    update: Callable[[PyTree, float, PyTree, PyTree], PyTree]
+    init: InitFn
+    params: ParamsFn
+    update: UpdateFn


# Additional custom types and prefixes used for serializing optimizer state.
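For context, a toy optimizer satisfying this `(init, params, update)` protocol triple might look like the following sketch; the `sgd` helper is hypothetical and not part of `invrs_opt`:

```python
import jax
from invrs_opt import base  # Assumed module path for the `Optimizer` dataclass.


def sgd(learning_rate: float) -> base.Optimizer:
    """Toy gradient descent optimizer whose state is simply the params."""

    def init_fn(params):
        return params

    def params_fn(state):
        return state

    def update_fn(*, grad, value, params, state):
        del value, params  # Unused; plain gradient descent needs only `grad`.
        return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, state, grad)

    return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
```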
106 changes: 68 additions & 38 deletions src/invrs_opt/lbfgsb/lbfgsb.py
@@ -36,6 +36,8 @@

# Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
MAXCOR_MAX_VALUE = 100
+MAXCOR_DEFAULT = 20
+LINE_SEARCH_MAX_STEPS_DEFAULT = 100

# Maps bound scenarios to integers.
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
@@ -49,10 +51,47 @@


def lbfgsb(
-    maxcor: int,
-    line_search_max_steps: int,
+    maxcor: int = MAXCOR_DEFAULT,
+    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
"""Return an optimizer implementing the standard L-BFGS-B algorithm."""
"""Return an optimizer implementing the standard L-BFGS-B algorithm.
This optimizer wraps scipy's implementation of the algorithm, and provides
a jax-style API to the scheme. The optimizer works with custom types such
as the `BoundedArray` to constrain the optimization variable.
Example usage is as follows:
def fn(x):
leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
return jnp.sum(jnp.asarray(leaves_sum_sq))
x0 = {
"a": jnp.ones((3,)),
"b": BoundedArray(
value=-jnp.ones((2, 5)),
lower_bound=-5,
upper_bound=5,
),
}
opt = lbfgsb(maxcor=20, line_search_max_steps=100)
state = opt.init(x0)
for _ in range(10):
x = opt.params(state)
value, grad = jax.value_and_grad(fn)(x)
state = opt.update(grad, value, state)
While the algorithm can work with pytrees of jax arrays, numpy arrays can
also be used. Thus, e.g. the optimizer can directly be used with autograd.
Args:
maxcor: The maximum number of variable metric corrections used to define
the limited memory matrix, in the L-BFGS-B scheme.
line_search_max_steps: The maximum number of steps in the line search.
Returns:
The `base.Optimizer`.
"""
    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
@@ -61,11 +100,30 @@ def lbfgsb(


def density_lbfgsb(
-    maxcor: int,
-    line_search_max_steps: int,
    beta: float,
+    maxcor: int = MAXCOR_DEFAULT,
+    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
"""Return an L-BFGS-B optimizer with additional transforms for density arrays."""
"""Return an L-BFGS-B optimizer with additional transforms for density arrays.
Parameters that are of type `DensityArray2D` are represented as latent parameters
that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
where the kernel has a full-width at half-maximum determined by the minimum width
and spacing parameters of the `DensityArray2D`. Where the bounds differ, the
density is scaled before the transform is applied, and then unscaled afterwards.
Args:
beta: Determines the steepness of the thresholding.
maxcor: The maximum number of variable metric corrections used to define
the limited memory matrix, in the L-BFGS-B scheme.
line_search_max_steps: The maximum number of steps in the line search.
Returns:
The `base.Optimizer`.
"""

    def transform_fn(tree: PyTree) -> PyTree:
        return tree_util.tree_map(
@@ -98,41 +156,13 @@ def transformed_lbfgsb(
    line_search_max_steps: int,
    transform_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
-    """Construct an optimizer implementing the L-BFGS-B algorithm.
+    """Construct a latent parameter L-BFGS-B optimizer.

    The optimized parameters are termed latent parameters, from which the
    actual parameters returned by the optimizer are obtained using the
    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
    the identity), this is equivalent to the standard L-BFGS-B algorithm.
-    This optimizer wraps scipy's implementation of the algorithm, and provides
-    a jax-style API to the scheme. The optimizer works with custom types such
-    as the `BoundedArray` to constrain the optimization variable.
-
-    Example usage is as follows:
-
-        def fn(x):
-            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
-            return jnp.sum(jnp.asarray(leaves_sum_sq))
-
-        x0 = {
-            "a": jnp.ones((3,)),
-            "b": BoundedArray(
-                value=-jnp.ones((2, 5)),
-                lower_bound=-5,
-                upper_bound=5,
-            ),
-        }
-
-        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
-        state = opt.init(x0)
-        for _ in range(10):
-            x = opt.params(state)
-            value, grad = jax.value_and_grad(fn)(x)
-            state = opt.update(grad, value, state)
-
-    While the algorithm can work with pytrees of jax arrays, numpy arrays can
-    also be used. Thus, e.g. the optimizer can directly be used with autograd.
    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
@@ -204,7 +234,7 @@ def update_fn(
    return base.Optimizer(
        init=init_fn,
        params=params_fn,
-        update=update_fn,  # type: ignore[arg-type]
+        update=update_fn,
    )
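To illustrate the latent parameter scheme, here is a sketch of `transformed_lbfgsb` with a custom `transform_fn`; the clipping transform is purely illustrative, and the call assumes the signature visible in this diff:

```python
import jax.numpy as jnp
from jax import tree_util

from invrs_opt.lbfgsb import lbfgsb


def clip_transform(tree):
    # Hypothetical transform: clip each latent leaf to [0, 1]. Using the
    # identity (lambda x: x) instead would recover plain L-BFGS-B.
    return tree_util.tree_map(lambda x: jnp.clip(x, 0.0, 1.0), tree)


opt = lbfgsb.transformed_lbfgsb(
    maxcor=20,
    line_search_max_steps=100,
    transform_fn=clip_transform,
)
```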


@@ -503,7 +533,7 @@ def _configure_bounds(
    lower_bound_array = [0.0 if x is None else x for x in lower_bound]
    upper_bound_array = [0.0 if x is None else x for x in upper_bound]
    return (
-        onp.asarray(lower_bound_array),
-        onp.asarray(upper_bound_array),
+        onp.asarray(lower_bound_array, onp.float64),
+        onp.asarray(upper_bound_array, onp.float64),
        onp.asarray(bound_type),
    )
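As intuition for the filter-and-threshold transform described in the `density_lbfgsb` docstring, a minimal sketch of the tanh projection step (omitting the Gaussian filtering and bound rescaling) is:

```python
import jax.numpy as jnp


def tanh_projection(array: jnp.ndarray, beta: float) -> jnp.ndarray:
    # Maps values in (-1, 1) toward the bounds; larger beta yields a steeper,
    # more nearly binary result, while the endpoints -1 and 1 are preserved.
    return jnp.tanh(beta * array) / jnp.tanh(beta)


x = jnp.linspace(-1.0, 1.0, 5)
print(tanh_projection(x, beta=2.0))  # Interior values are pushed outward.
```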
64 changes: 32 additions & 32 deletions src/invrs_opt/lbfgsb/transform.py
@@ -16,38 +16,6 @@
GAUSSIAN_FWHM_SIZE_MULTIPLE: float = 3.0


-def normalized_array_from_density(density: types.Density2DArray) -> jnp.ndarray:
-    """Returns an array with values scaled to the range `(-1, 1)`."""
-    value_mid = (density.upper_bound + density.lower_bound) / 2
-    value_range = density.upper_bound - density.lower_bound
-    return jnp.asarray(2 * (density.array - value_mid) / value_range)
-
-
-def rescale_array_for_density(
-    array: jnp.ndarray,
-    density: types.Density2DArray,
-) -> jnp.ndarray:
-    """Rescales an array for the bounds defined by `density`."""
-    value_mid = (density.upper_bound + density.lower_bound) / 2
-    value_range = density.upper_bound - density.lower_bound
-    return array / 2 * value_range + value_mid
-
-
-def apply_fixed_pixels(density: types.Density2DArray) -> types.Density2DArray:
-    """Set fixed pixels with their required values."""
-    fixed_solid = density.fixed_solid
-    fixed_void = density.fixed_void
-    (array,), treedef = tree_util.tree_flatten(density)
-    if fixed_solid is not None:
-        array = jnp.where(fixed_solid, density.upper_bound, array)
-    if fixed_void is not None:
-        array = jnp.where(fixed_void, density.lower_bound, array)
-    transformed_density: types.Density2DArray = tree_util.tree_unflatten(
-        treedef, (array,)
-    )
-    return transformed_density


def density_gaussian_filter_and_tanh(
    density: types.Density2DArray,
    beta: float,
@@ -95,6 +63,38 @@ def density_gaussian_filter_and_tanh(
    return transformed_density


def normalized_array_from_density(density: types.Density2DArray) -> jnp.ndarray:
    """Returns an array with values scaled to the range `(-1, 1)`."""
    value_mid = (density.upper_bound + density.lower_bound) / 2
    value_range = density.upper_bound - density.lower_bound
    return jnp.asarray(2 * (density.array - value_mid) / value_range)


def rescale_array_for_density(
    array: jnp.ndarray,
    density: types.Density2DArray,
) -> jnp.ndarray:
    """Rescales an array for the bounds defined by `density`."""
    value_mid = (density.upper_bound + density.lower_bound) / 2
    value_range = density.upper_bound - density.lower_bound
    return array / 2 * value_range + value_mid


def apply_fixed_pixels(density: types.Density2DArray) -> types.Density2DArray:
    """Set fixed pixels with their required values."""
    fixed_solid = density.fixed_solid
    fixed_void = density.fixed_void
    (array,), treedef = tree_util.tree_flatten(density)
    if fixed_solid is not None:
        array = jnp.where(fixed_solid, density.upper_bound, array)
    if fixed_void is not None:
        array = jnp.where(fixed_void, density.lower_bound, array)
    transformed_density: types.Density2DArray = tree_util.tree_unflatten(
        treedef, (array,)
    )
    return transformed_density
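A quick round-trip sketch of the normalization helpers above, assuming the module path implied by this file and the `Density2DArray` constructor defaults seen in the tests:

```python
import jax.numpy as jnp
from totypes import types

# Assumed import path for the helpers shown above.
from invrs_opt.lbfgsb.transform import (
    normalized_array_from_density,
    rescale_array_for_density,
)

# Normalization maps values into (-1, 1); rescaling restores the (0, 1) bounds.
density = types.Density2DArray(
    array=jnp.full((4, 4), 0.75), lower_bound=0, upper_bound=1
)
normalized = normalized_array_from_density(density)        # All values 0.5.
restored = rescale_array_for_density(normalized, density)  # All values 0.75.
```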


def conv(x: jnp.ndarray, kernel: jnp.ndarray, padding: str) -> jnp.ndarray:
    """Convolves `x` with `kernel`."""
    assert x.ndim == 4
54 changes: 50 additions & 4 deletions tests/lbfgsb/test_lbfgsb.py
@@ -8,20 +8,66 @@
import jax
import jax.numpy as jnp
import numpy as onp
-import parameterized
+from parameterized import parameterized
import scipy.optimize as spo

from invrs_opt.lbfgsb import lbfgsb
from totypes import types


class DensityLbfgsbBoundsTest(unittest.TestCase):
    @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]])
    def test_respects_bounds(self, lower_bound, upper_bound, sign):
        def loss_fn(density):
            return sign * jnp.sum(density.array)

        params = types.Density2DArray(
            array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )
        opt = lbfgsb.density_lbfgsb(beta=2)
        state = opt.init(params)
        for _ in range(10):
            params = opt.params(state)
            value, grad = jax.value_and_grad(loss_fn)(params)
            state = opt.update(grad=grad, value=value, params=params, state=state)

        params = opt.params(state)
        expected = upper_bound if sign < 0 else lower_bound
        onp.testing.assert_allclose(params.array, expected)


class LbfgsbBoundsTest(unittest.TestCase):
    @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]])
    def test_respects_bounds(self, lower_bound, upper_bound, sign):
        def loss_fn(density):
            return sign * jnp.sum(density.array)

        params = types.Density2DArray(
            array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )
        opt = lbfgsb.lbfgsb()
        state = opt.init(params)
        for _ in range(10):
            params = opt.params(state)
            value, grad = jax.value_and_grad(loss_fn)(params)
            state = opt.update(grad=grad, value=value, params=params, state=state)

        params = opt.params(state)
        expected = upper_bound if sign < 0 else lower_bound
        onp.testing.assert_allclose(params.array, expected)


class LbfgsbInputValidationTest(unittest.TestCase):
-    @parameterized.parameterized.expand([[0], [-1], [500], ["abc"]])
+    @parameterized.expand([[0], [-1], [500], ["abc"]])
    def test_maxcor_validation(self, invalid_maxcor):
        with self.assertRaisesRegex(ValueError, "`maxcor` must be greater than 0"):
            lbfgsb.lbfgsb(maxcor=invalid_maxcor, line_search_max_steps=100)

-    @parameterized.parameterized.expand([[0], [-1], ["abc"]])
+    @parameterized.expand([[0], [-1], ["abc"]])
    def test_line_search_max_steps_validation(self, invalid_line_search_max_steps):
        with self.assertRaisesRegex(ValueError, "`line_search_max_steps` must be "):
            lbfgsb.lbfgsb(
@@ -183,7 +229,7 @@ def loss_fn(x):
        # the other we are using float32.
        onp.testing.assert_allclose(scipy_values[:10], wrapper_values[:10], rtol=1e-6)

-    @parameterized.parameterized.expand(
+    @parameterized.expand(
        [
            [2.0],
            [jnp.ones((3,))],