Skip to content

Commit

Permalink
Merge pull request #5 from invrs-io/loss
Browse files Browse the repository at this point in the history
Clean up loss module
  • Loading branch information
mfschubert authored Sep 25, 2023
2 parents d6e3804 + c203152 commit 2d7e4a2
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 47 deletions.
12 changes: 4 additions & 8 deletions .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4

- name: Checkout submodule
uses: actions/checkout@v4
with:
repository: invrs-io/totypes
path: totypes
token: ${{ secrets.INVRSIO_PAT }}

- name: Set up Python
uses: actions/setup-python@v4
with:
Expand All @@ -32,7 +25,6 @@ jobs:
- name: Test pre-commit hooks
run: |
python -m pip install --upgrade pip
pip install "totypes/"
pip install pre-commit
pre-commit run -a
Expand Down Expand Up @@ -66,6 +58,10 @@ jobs:
run: |
darglint src --strictness=short --ignore-raise=ValueError
- name: Validate docstrings
run: |
mypy src
tests:
runs-on: ubuntu-latest
steps:
Expand Down
7 changes: 0 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ repos:
args: [--exit-zero]
exclude: ^tests/

# Type annotation
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
hooks:
- id: mypy
exclude: ^(notebooks|tests)

# Linter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.287
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
"bump-my-version",
"darglint",
"invrs_gym[tests]",
"mypy",
"pre-commit",
]

Expand All @@ -46,10 +47,6 @@ build-backend = "setuptools.build_meta"
line-length = 88
target-version = ['py310']

[tool.mypy]
python_version = "3.10"
strict = true

[tool.isort]
multi_line_output = 3
line_length = 88
Expand Down
13 changes: 7 additions & 6 deletions src/invrs_gym/challenge/ceviche/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import agjax
import agjax # type: ignore[import]
import jax
import jax.numpy as jnp
import numpy as onp
from totypes import types # type: ignore[attr-defined]
from totypes import types # type: ignore[import,attr-defined,unused-ignore]

from invrs_gym.challenge.ceviche import defaults
from invrs_gym.loss import transmission_loss
Expand Down Expand Up @@ -85,7 +85,7 @@ def sim_fn(
excite_port_idxs: Sequence[int],
wavelengths_nm: Optional[jnp.ndarray],
max_parallelizm: Optional[int],
) -> Tuple[jnp.ndarray, onp.ndarray]:
) -> Tuple[jnp.ndarray, onp.ndarray[Any, Any]]:
s_params, fields = self.ceviche_model.simulate(
design_variable, excite_port_idxs, wavelengths_nm, max_parallelizm
)
Expand Down Expand Up @@ -227,8 +227,8 @@ def loss(self, response: jnp.ndarray) -> jnp.ndarray:
transmission=transmission,
window_lower_bound=lb,
window_upper_bound=ub,
transmission_exponent=TRANSMISSION_EXPONENT,
scalar_exponent=SCALAR_EXPONENT,
transmission_exponent=jnp.asarray(TRANSMISSION_EXPONENT),
scalar_exponent=jnp.asarray(SCALAR_EXPONENT),
)

def metrics(
Expand Down Expand Up @@ -269,7 +269,8 @@ def _wavelength_bound(
)

repeats = transmission_shape[0] // band_bound.shape[0]
return jnp.repeat(band_bound, repeats, axis=0)
repeated: jnp.ndarray = jnp.repeat(band_bound, repeats, axis=0)
return repeated


# -----------------------------------------------------------------------------
Expand Down
14 changes: 10 additions & 4 deletions src/invrs_gym/challenge/ceviche/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
from typing import Union

import jax.numpy as jnp
from ceviche_challenges import beam_splitter, mode_converter, model_base, params
from ceviche_challenges import ( # type: ignore[import]
beam_splitter,
mode_converter,
model_base,
params,
)
from ceviche_challenges import units as u
from ceviche_challenges import waveguide_bend, wdm
from totypes import symmetry # type: ignore[attr-defined]
from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore]

DeviceSpec = Union[
beam_splitter.spec.BeamSplitterSpec,
Expand All @@ -17,9 +22,10 @@
Model = model_base.Model


def _linear_from_decibels(x_decibels: jnp.ndarray) -> jnp.ndarray:
def _linear_from_decibels(x_decibels: float) -> jnp.ndarray:
"""Converts a quantity in decibels to linear."""
return 10 ** (x_decibels / 10)
linear: jnp.ndarray = jnp.asarray(10 ** (x_decibels / 10))
return linear


WG_WIDTH = 400 * u.nm
Expand Down
58 changes: 40 additions & 18 deletions src/invrs_gym/loss/transmission_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def orthotope_smooth_transmission_loss(
) -> jnp.ndarray:
"""Compute a scalar loss from a array based on an orthotope transmission window.
The loss is related to an orthotope window, i.e. a orthotope target region within
the space that contains valid responses. With values for `transmission_exponent`
and `scalar_exponent` of `0.5` and `2.0`, respectively, this loss function is
equivalent to that of [2022 Schubert](https://arxiv.org/abs/2201.12965).
The loss is related to an orthotope window in tansmission space, i.e. the space
of the squared magnitude of scattering parameters. With values for
`transmission_exponent` and `scalar_exponent` of `1.0` and `2.0`, respectively,
this loss function is equivalent to that of [2022 Schubert]
(https://arxiv.org/abs/2201.12965).
Args:
transmission: The transmission array for which the loss is to be calculated.
Expand All @@ -33,23 +34,28 @@ def orthotope_smooth_transmission_loss(
Returns:
The scalar loss value.
"""
# The signed distance to the target window is positive
elementwise_signed_distance = elementwise_signed_distance_to_window(
# Compute the signed psuedodistance. This is equal to the signed distance to the
# nearest bound, except when the bounds are the min and max physical transmission
# values, in which case the distance is equal to the window size.
elementwise_signed_psuedodistance = elementwise_signed_psuedodistance_to_window(
transmission=transmission**transmission_exponent,
window_lower_bound=window_lower_bound**transmission_exponent,
window_upper_bound=window_upper_bound**transmission_exponent,
)

window_size = jnp.minimum(
window_upper_bound, _MAX_PHYSICAL_TRANSMISSION
) - jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION)
# Scale the signed distance by the maximum window size.
lower_bound = jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION)
upper_bound = jnp.minimum(window_upper_bound, _MAX_PHYSICAL_TRANSMISSION)
window_size = upper_bound - lower_bound
elementwise_signed_psuedodistance /= jnp.amin(window_size)

# TODO: check whether dividing by max or min is appropriate.
transformed_elementwise_signed_distance = jax.nn.softplus(
elementwise_signed_distance / jnp.amax(window_size)
elementwise_signed_psuedodistance
)

return jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent
loss: jnp.ndarray = (
jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent
)
return loss


def distance_to_window(
Expand All @@ -70,29 +76,45 @@ def distance_to_window(
Returns:
The elementwise signed distance.
"""
elementwise_signed_distance = elementwise_signed_distance_to_window(
elementwise_signed_distance = elementwise_signed_psuedodistance_to_window(
transmission=transmission,
window_lower_bound=window_lower_bound,
window_upper_bound=window_upper_bound,
)
elementwise_distance = jnp.maximum(elementwise_signed_distance, 0.0)
return jnp.linalg.norm(elementwise_distance)
distance: jnp.ndarray = jnp.linalg.norm(elementwise_distance)
return distance


def elementwise_signed_distance_to_window(
def elementwise_signed_psuedodistance_to_window(
transmission: jnp.ndarray,
window_lower_bound: jnp.ndarray,
window_upper_bound: jnp.ndarray,
) -> jnp.ndarray:
"""Returns the elementwise signed distance to a transmission window.
"""Returns the elementwise signed psuedodistance to a transmission window.
The psuedodistance is given by,
- the distance to the nearest bound, when both bounds are within the
window (e.g. the lower bound is greater than the minimum physical
transmission value).
- the distance to the upper bound, when the lower bound is less than or
equal to the minimum physical transmission value, and the upper bound
is less than the maximum physical transmission value.
- the distance to the lower bound, when the upper bound is greater than
or equal to the maximum physicall transmission value, and the lower
bound is greater than the minimum physical transmission value.
- the difference between the maximum and minimum physical transmission
value, when both bounds equal or exceed their physical extremal values.
In this case, the psuedodistance has no dependence on `transmission`.
Args:
transmission: The transmission for which the signed distance is sought.
window_lower_bound: Array defining the transmission window lower bound.
window_upper_bound: Array defining the transmission window upper bound.
Returns:
The elementwise signed distance.
The elementwise signed psuedodistance.
"""
# Signed distance to lower bound is positive when `transmission` is below the lower
# bound and negative when it is above the lower bound. When the lower bound is less
Expand Down

0 comments on commit 2d7e4a2

Please sign in to comment.