Retry debug #39

Merged 5 commits on Oct 10, 2024.

Changes from all commits
.bumpversion.toml (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "v0.9.2"
+current_version = "v0.9.3"
 commit = true
 commit_args = "--no-verify"
 tag = true
README.md (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 # invrs-opt - Optimization algorithms for inverse design
-`v0.9.2`
+`v0.9.3`

 ## Overview

pyproject.toml (4 changes: 2 additions & 2 deletions)

@@ -1,7 +1,7 @@
 [project]

 name = "invrs_opt"
-version = "v0.9.2"
+version = "v0.9.3"
 description = "Algorithms for inverse design"
 keywords = ["topology", "optimization", "jax", "inverse design"]
 readme = "README.md"
@@ -16,7 +16,7 @@ maintainers = [
 ]

 dependencies = [
-    "jax < 0.4.32",
+    "jax <= 0.4.35",
     "jaxlib",
     "numpy",
     "requests",
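Aside from the version bump, the substantive change in this file is the relaxed jax pin: versions 0.4.32 through 0.4.35 are now admitted. A minimal sketch of what each specifier accepts, not part of the PR and assuming the `packaging` library is available:

from packaging.specifiers import SpecifierSet

old = SpecifierSet("< 0.4.32")
new = SpecifierSet("<= 0.4.35")
for candidate in ["0.4.31", "0.4.33", "0.4.35", "0.4.36"]:
    # Membership tests a version string against the specifier set.
    print(candidate, candidate in old, candidate in new)
# 0.4.31 satisfies both; 0.4.33 and 0.4.35 satisfy only the new pin;
# 0.4.36 satisfies neither.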
src/invrs_opt/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """

-__version__ = "v0.9.2"
+__version__ = "v0.9.3"
 __author__ = "Martin F. Schubert <[email protected]>"

 from invrs_opt import parameterization as parameterization
src/invrs_opt/optimizers/lbfgsb.py (53 changes: 26 additions & 27 deletions)

@@ -3,7 +3,6 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """

-import copy
 import dataclasses
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

@@ -28,6 +27,7 @@
 NDArray = onp.ndarray[Any, Any]
 PyTree = Any
 ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
+NumpyLbfgsbDict = Dict[str, NDArray]
 JaxLbfgsbDict = Dict[str, jnp.ndarray]
 LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
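The new `NumpyLbfgsbDict` alias distinguishes the host-side (numpy) form of the L-BFGS-B state from its traced (jax) form, which keeps the `JaxLbfgsbDict` alias. A hypothetical illustration of the split, not taken from the PR:

from typing import Any, Dict

import jax.numpy as jnp
import numpy as onp

NumpyLbfgsbDict = Dict[str, onp.ndarray]  # host-side state: mutable numpy arrays
JaxLbfgsbDict = Dict[str, jnp.ndarray]    # traced state: immutable jax arrays

numpy_state: NumpyLbfgsbDict = {"x": onp.zeros(3)}
jax_state: JaxLbfgsbDict = {k: jnp.asarray(v) for k, v in numpy_state.items()}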
@@ -299,7 +299,7 @@ def parameterized_lbfgsb(
     def init_fn(params: PyTree) -> LbfgsbState:
         """Initializes the optimization state."""

-        def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
+        def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
             lower_bound = types.extract_lower_bound(latent_params)
             upper_bound = types.extract_upper_bound(latent_params)
             scipy_lbfgsb_state = ScipyLbfgsbState.init(
@@ -312,7 +312,7 @@ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
                 gtol=gtol,
             )
             latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
-            return latent_params, scipy_lbfgsb_state.to_jax()
+            return latent_params, scipy_lbfgsb_state.to_dict()

         latent_params = _init_latents(params)
         metadata, latents = param_base.partition_density_metadata(latent_params)
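Returning plain numpy from `_init_state_pure`, via the renamed `to_dict`, matches the convention for host-side functions wrapped with `jax.pure_callback`: they run outside of tracing and must hand back host (numpy) results matching the declared shapes. The connection to this PR is an inference; the sketch below only illustrates that general convention and is not code from the diff:

import jax
import jax.numpy as jnp
import numpy as onp

def host_increment(x: onp.ndarray) -> onp.ndarray:
    # Executes on the host: receives and returns numpy arrays.
    return onp.asarray(x) + 1

@jax.jit
def f(x: jnp.ndarray) -> jnp.ndarray:
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_increment, result_shape, x)

print(f(jnp.arange(3.0)))  # [1. 2. 3.]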
@@ -346,15 +346,15 @@ def _update_pure(
             flat_latent_grad: PyTree,
             value: jnp.ndarray,
             jax_lbfgsb_state: JaxLbfgsbDict,
-        ) -> Tuple[PyTree, JaxLbfgsbDict]:
+        ) -> Tuple[PyTree, NumpyLbfgsbDict]:
             assert onp.size(value) == 1
             scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
             scipy_lbfgsb_state.update(
                 grad=onp.array(flat_latent_grad, dtype=onp.float64),
                 value=onp.array(value, dtype=onp.float64),
             )
             flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
-            return flat_latent_params, scipy_lbfgsb_state.to_jax()
+            return flat_latent_params, scipy_lbfgsb_state.to_dict()

         step, _, latent_params, jax_lbfgsb_state = state
         metadata, latents = param_base.partition_density_metadata(latent_params)
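The cycle implied by `_update_pure`: rebuild the mutable scipy-side state from the traced dict, apply one L-BFGS-B update with a float64 value and gradient, then snapshot back to numpy. A hedged outline of that cycle, where the class and method names come from the diff but the loop wrapper, the pre-initialized `scipy_lbfgsb_state`, and `value_and_grad_fn` are hypothetical:

import numpy as onp

state_dict = scipy_lbfgsb_state.to_dict()  # assumed initialized via ScipyLbfgsbState.init
for _ in range(10):  # illustrative step count
    state = ScipyLbfgsbState.from_jax(state_dict)
    value, grad = value_and_grad_fn(state.x)  # hypothetical objective function
    state.update(
        grad=onp.array(grad, dtype=onp.float64),
        value=onp.array(value, dtype=onp.float64),
    )
    state_dict = state.to_dict()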
@@ -696,31 +696,30 @@ def __post_init__(self) -> None:
         _validate_array_dtype(self._upper_bound, onp.float64)
         _validate_array_dtype(self._bound_type, int)

-    def to_jax(self) -> Dict[str, jnp.ndarray]:
+    def to_dict(self) -> NumpyLbfgsbDict:
         """Generates a dictionary of jax arrays defining the state."""
         return dict(
-            x=jnp.asarray(self.x),
-            converged=jnp.asarray(self.converged),
-            _maxcor=jnp.asarray(self._maxcor),
-            _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
-            _ftol=jnp.asarray(self._ftol),
-            _gtol=jnp.asarray(self._gtol),
-            _wa=jnp.asarray(self._wa),
-            _iwa=jnp.asarray(self._iwa),
+            x=onp.asarray(self.x),
+            converged=onp.asarray(self.converged),
+            _maxcor=onp.asarray(self._maxcor),
+            _line_search_max_steps=onp.asarray(self._line_search_max_steps),
+            _ftol=onp.asarray(self._ftol),
+            _gtol=onp.asarray(self._gtol),
+            _wa=onp.asarray(self._wa),
+            _iwa=onp.asarray(self._iwa),
             _task=_array_from_s60_str(self._task),
             _csave=_array_from_s60_str(self._csave),
-            _lsave=jnp.asarray(self._lsave),
-            _isave=jnp.asarray(self._isave),
-            _dsave=jnp.asarray(self._dsave),
-            _lower_bound=jnp.asarray(self._lower_bound),
-            _upper_bound=jnp.asarray(self._upper_bound),
-            _bound_type=jnp.asarray(self._bound_type),
+            _lsave=onp.asarray(self._lsave),
+            _isave=onp.asarray(self._isave),
+            _dsave=onp.asarray(self._dsave),
+            _lower_bound=onp.asarray(self._lower_bound),
+            _upper_bound=onp.asarray(self._upper_bound),
+            _bound_type=onp.asarray(self._bound_type),
         )

     @classmethod
-    def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
+    def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
         """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
-        state_dict = copy.deepcopy(state_dict)
         return ScipyLbfgsbState(
             x=onp.array(state_dict["x"], dtype=onp.float64),
             converged=onp.asarray(state_dict["converged"], dtype=bool),
@@ -730,8 +729,8 @@ def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
             _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
             _wa=onp.array(state_dict["_wa"], onp.float64),
             _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
-            _task=_s60_str_from_array(state_dict["_task"]),
-            _csave=_s60_str_from_array(state_dict["_csave"]),
+            _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
+            _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
             _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
             _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
             _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
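Dropping `copy.deepcopy` (and the `copy` import) looks safe because `from_jax` already rebuilds each field with `onp.array`, which copies its input by default; that reading is an inference from the diff, not something stated in the PR. A minimal demonstration of the copying behavior:

import numpy as onp

src = onp.arange(3.0)
dst = onp.array(src, dtype=onp.float64)  # numpy copies by default here
dst[0] = 99.0
assert src[0] == 0.0  # the source array is untouched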
@@ -898,15 +897,15 @@ def _configure_bounds(
     )


-def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
+def _array_from_s60_str(s60_str: NDArray) -> NDArray:
     """Return a jax array for a numpy s60 string."""
     assert s60_str.shape == (1,)
     chars = [int(o) for o in s60_str[0]]
     chars.extend([32] * (59 - len(chars)))
-    return jnp.asarray(chars, dtype=int)
+    return onp.asarray(chars, dtype=int)


-def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
+def _s60_str_from_array(array: NDArray) -> NDArray:
     """Return a numpy s60 string for a jax array."""
     return onp.asarray(
         [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
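These two helpers pack the Fortran task/csave character buffers into plain integer arrays so the whole state dict stays array-valued, and unpack them again on the way back. A self-contained sketch of the round trip mirroring the helpers above, where the "S60" dtype and padding width are assumptions based on the visible lines:

import numpy as onp

def array_from_s60(s60_str: onp.ndarray) -> onp.ndarray:
    assert s60_str.shape == (1,)
    chars = [int(o) for o in s60_str[0]]    # bytes iterate as integers
    chars.extend([32] * (59 - len(chars)))  # pad with ASCII spaces
    return onp.asarray(chars, dtype=int)

def s60_from_array(array: onp.ndarray) -> onp.ndarray:
    packed = b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)
    return onp.asarray([packed], dtype="S60")  # dtype is an assumption

task = onp.asarray([b"START"], dtype="S60")
roundtrip = s60_from_array(array_from_s60(task))
assert roundtrip[0].startswith(b"START")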