Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up loss module #5

Merged
merged 6 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4

- name: Checkout submodule
uses: actions/checkout@v4
with:
repository: invrs-io/totypes
path: totypes
token: ${{ secrets.INVRSIO_PAT }}

- name: Set up Python
uses: actions/setup-python@v4
with:
Expand All @@ -32,7 +25,6 @@ jobs:
- name: Test pre-commit hooks
run: |
python -m pip install --upgrade pip
pip install "totypes/"
pip install pre-commit
pre-commit run -a

Expand Down Expand Up @@ -66,6 +58,10 @@ jobs:
run: |
darglint src --strictness=short --ignore-raise=ValueError

- name: Validate docstrings
run: |
mypy src

tests:
runs-on: ubuntu-latest
steps:
Expand Down
7 changes: 0 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ repos:
args: [--exit-zero]
exclude: ^tests/

# Type annotation
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
hooks:
- id: mypy
exclude: ^(notebooks|tests)

# Linter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.287
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
"bump-my-version",
"darglint",
"invrs_gym[tests]",
"mypy",
"pre-commit",
]

Expand All @@ -46,10 +47,6 @@ build-backend = "setuptools.build_meta"
line-length = 88
target-version = ['py310']

[tool.mypy]
python_version = "3.10"
strict = true

[tool.isort]
multi_line_output = 3
line_length = 88
Expand Down
13 changes: 7 additions & 6 deletions src/invrs_gym/challenge/ceviche/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import agjax
import agjax # type: ignore[import]
import jax
import jax.numpy as jnp
import numpy as onp
from totypes import types # type: ignore[attr-defined]
from totypes import types # type: ignore[import,attr-defined,unused-ignore]

from invrs_gym.challenge.ceviche import defaults
from invrs_gym.loss import transmission_loss
Expand Down Expand Up @@ -85,7 +85,7 @@ def sim_fn(
excite_port_idxs: Sequence[int],
wavelengths_nm: Optional[jnp.ndarray],
max_parallelizm: Optional[int],
) -> Tuple[jnp.ndarray, onp.ndarray]:
) -> Tuple[jnp.ndarray, onp.ndarray[Any, Any]]:
s_params, fields = self.ceviche_model.simulate(
design_variable, excite_port_idxs, wavelengths_nm, max_parallelizm
)
Expand Down Expand Up @@ -227,8 +227,8 @@ def loss(self, response: jnp.ndarray) -> jnp.ndarray:
transmission=transmission,
window_lower_bound=lb,
window_upper_bound=ub,
transmission_exponent=TRANSMISSION_EXPONENT,
scalar_exponent=SCALAR_EXPONENT,
transmission_exponent=jnp.asarray(TRANSMISSION_EXPONENT),
scalar_exponent=jnp.asarray(SCALAR_EXPONENT),
)

def metrics(
Expand Down Expand Up @@ -269,7 +269,8 @@ def _wavelength_bound(
)

repeats = transmission_shape[0] // band_bound.shape[0]
return jnp.repeat(band_bound, repeats, axis=0)
repeated: jnp.ndarray = jnp.repeat(band_bound, repeats, axis=0)
return repeated


# -----------------------------------------------------------------------------
Expand Down
14 changes: 10 additions & 4 deletions src/invrs_gym/challenge/ceviche/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
from typing import Union

import jax.numpy as jnp
from ceviche_challenges import beam_splitter, mode_converter, model_base, params
from ceviche_challenges import ( # type: ignore[import]
beam_splitter,
mode_converter,
model_base,
params,
)
from ceviche_challenges import units as u
from ceviche_challenges import waveguide_bend, wdm
from totypes import symmetry # type: ignore[attr-defined]
from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore]

DeviceSpec = Union[
beam_splitter.spec.BeamSplitterSpec,
Expand All @@ -17,9 +22,10 @@
Model = model_base.Model


def _linear_from_decibels(x_decibels: jnp.ndarray) -> jnp.ndarray:
def _linear_from_decibels(x_decibels: float) -> jnp.ndarray:
"""Converts a quantity in decibels to linear."""
return 10 ** (x_decibels / 10)
linear: jnp.ndarray = jnp.asarray(10 ** (x_decibels / 10))
return linear


WG_WIDTH = 400 * u.nm
Expand Down
58 changes: 40 additions & 18 deletions src/invrs_gym/loss/transmission_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def orthotope_smooth_transmission_loss(
) -> jnp.ndarray:
"""Compute a scalar loss from a array based on an orthotope transmission window.

The loss is related to an orthotope window, i.e. a orthotope target region within
the space that contains valid responses. With values for `transmission_exponent`
and `scalar_exponent` of `0.5` and `2.0`, respectively, this loss function is
equivalent to that of [2022 Schubert](https://arxiv.org/abs/2201.12965).
The loss is related to an orthotope window in tansmission space, i.e. the space
of the squared magnitude of scattering parameters. With values for
`transmission_exponent` and `scalar_exponent` of `1.0` and `2.0`, respectively,
this loss function is equivalent to that of [2022 Schubert]
(https://arxiv.org/abs/2201.12965).

Args:
transmission: The transmission array for which the loss is to be calculated.
Expand All @@ -33,23 +34,28 @@ def orthotope_smooth_transmission_loss(
Returns:
The scalar loss value.
"""
# The signed distance to the target window is positive
elementwise_signed_distance = elementwise_signed_distance_to_window(
# Compute the signed psuedodistance. This is equal to the signed distance to the
# nearest bound, except when the bounds are the min and max physical transmission
# values, in which case the distance is equal to the window size.
elementwise_signed_psuedodistance = elementwise_signed_psuedodistance_to_window(
transmission=transmission**transmission_exponent,
window_lower_bound=window_lower_bound**transmission_exponent,
window_upper_bound=window_upper_bound**transmission_exponent,
)

window_size = jnp.minimum(
window_upper_bound, _MAX_PHYSICAL_TRANSMISSION
) - jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION)
# Scale the signed distance by the maximum window size.
lower_bound = jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION)
upper_bound = jnp.minimum(window_upper_bound, _MAX_PHYSICAL_TRANSMISSION)
window_size = upper_bound - lower_bound
elementwise_signed_psuedodistance /= jnp.amin(window_size)

# TODO: check whether dividing by max or min is appropriate.
transformed_elementwise_signed_distance = jax.nn.softplus(
elementwise_signed_distance / jnp.amax(window_size)
elementwise_signed_psuedodistance
)

return jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent
loss: jnp.ndarray = (
jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent
)
return loss


def distance_to_window(
Expand All @@ -70,29 +76,45 @@ def distance_to_window(
Returns:
The elementwise signed distance.
"""
elementwise_signed_distance = elementwise_signed_distance_to_window(
elementwise_signed_distance = elementwise_signed_psuedodistance_to_window(
transmission=transmission,
window_lower_bound=window_lower_bound,
window_upper_bound=window_upper_bound,
)
elementwise_distance = jnp.maximum(elementwise_signed_distance, 0.0)
return jnp.linalg.norm(elementwise_distance)
distance: jnp.ndarray = jnp.linalg.norm(elementwise_distance)
return distance


def elementwise_signed_distance_to_window(
def elementwise_signed_psuedodistance_to_window(
transmission: jnp.ndarray,
window_lower_bound: jnp.ndarray,
window_upper_bound: jnp.ndarray,
) -> jnp.ndarray:
"""Returns the elementwise signed distance to a transmission window.
"""Returns the elementwise signed psuedodistance to a transmission window.

The psuedodistance is given by,

- the distance to the nearest bound, when both bounds are within the
window (e.g. the lower bound is greater than the minimum physical
transmission value).
- the distance to the upper bound, when the lower bound is less than or
equal to the minimum physical transmission value, and the upper bound
is less than the maximum physical transmission value.
- the distance to the lower bound, when the upper bound is greater than
or equal to the maximum physicall transmission value, and the lower
bound is greater than the minimum physical transmission value.
- the difference between the maximum and minimum physical transmission
value, when both bounds equal or exceed their physical extremal values.
In this case, the psuedodistance has no dependence on `transmission`.

Args:
transmission: The transmission for which the signed distance is sought.
window_lower_bound: Array defining the transmission window lower bound.
window_upper_bound: Array defining the transmission window upper bound.

Returns:
The elementwise signed distance.
The elementwise signed psuedodistance.
"""
# Signed distance to lower bound is positive when `transmission` is below the lower
# bound and negative when it is above the lower bound. When the lower bound is less
Expand Down